import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Callable


def train(
    model: nn.Module,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epochs: int = 100,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-8,
    device: str = "cuda",
) -> tuple[list[float], list[float]]:
    model.to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    avg_training_losses: list[float] = []
    avg_validation_losses: list[float] = []

    for epoch in range(epochs):
        # ----- training -----
        model.train()
        batch_losses = []
        for x, y in tqdm(training_loader):
            x = x.long().to(device)  # integer ids, important for Embedding
            y = y.long().to(device)  # must be (B,) class indices for CE
            optimizer.zero_grad()
            logits = model(x)  # (B, 256)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        avg_training_losses.append(sum(batch_losses) / len(batch_losses))

        # ----- validation -----
        model.eval()
        with torch.no_grad():
            losses = []
            for x, y in validation_loader:
                x = x.long().to(device)
                y = y.long().to(device)
                logits = model(x)
                loss = loss_fn(logits, y)
                losses.append(loss.item())
        avg_loss = sum(losses) / len(losses)
        avg_validation_losses.append(avg_loss)
        tqdm.write(f"epoch: {epoch + 1}, avg val loss = {avg_loss:.4f}")

    return avg_training_losses, avg_validation_losses
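

# For reference, a minimal usage sketch. ByteModel and the random tensors below
# are hypothetical stand-ins (not part of the training code above); they exist
# only to match the long-typed inputs and (B, 256) logit shape the loop expects.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    class ByteModel(nn.Module):
        """Toy next-byte classifier: embeds a window of byte ids, predicts the next byte."""

        def __init__(self, context: int = 8, dim: int = 64):
            super().__init__()
            self.embed = nn.Embedding(256, dim)        # byte vocabulary
            self.head = nn.Linear(context * dim, 256)  # logits over next byte

        def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, context)
            return self.head(self.embed(x).flatten(1))       # (B, 256)

    # Random stand-in data: byte windows as inputs, next-byte ids as targets.
    xs = torch.randint(0, 256, (1024, 8))
    ys = torch.randint(0, 256, (1024,))
    train_loader = DataLoader(TensorDataset(xs[:896], ys[:896]), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(xs[896:], ys[896:]), batch_size=32)

    train_losses, val_losses = train(
        ByteModel(),
        train_loader,
        val_loader,
        loss_fn=nn.CrossEntropyLoss(),
        epochs=3,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )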