import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from typing import Callable


def train(
    model: nn.Module,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epochs: int = 100,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-8,
) -> tuple[list[float], list[float]]:
    """Train ``model`` with AdamW, validating after every epoch.

    Both loaders are expected to yield batches that can be passed directly
    to ``model`` and compared against themselves via ``loss_fn`` — an
    autoencoder-style reconstruction setup: ``loss_fn(model(x), x)``.

    Args:
        model: Network to optimize in place.
        training_loader: Batches used for gradient updates.
        validation_loader: Batches used only for loss reporting.
        loss_fn: Maps ``(prediction, target)`` to a scalar loss tensor.
        epochs: Number of full passes over ``training_loader``.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight-decay coefficient.

    Returns:
        ``(avg_training_losses, avg_validation_losses)`` — one entry per
        epoch each.

    Raises:
        ZeroDivisionError: If either loader yields zero batches.
    """
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    avg_training_losses: list[float] = []
    avg_validation_losses: list[float] = []

    for epoch in range(epochs):
        # --- training pass ---
        model.train()
        batch_losses: list[float] = []
        for data in tqdm(training_loader):
            optimizer.zero_grad()
            x_hat = model(data)
            loss = loss_fn(x_hat, data)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        avg_training_losses.append(sum(batch_losses) / len(batch_losses))

        # --- validation pass ---
        # BUG FIX: switch to eval mode so dropout is disabled and batch-norm
        # uses (and does not update) its running statistics during validation.
        # The original left the model in train mode here, which both skews the
        # reported validation loss and pollutes batch-norm state.
        model.eval()
        with torch.no_grad():
            losses = [
                loss_fn(model(data), data).item() for data in validation_loader
            ]
        avg_loss = sum(losses) / len(losses)
        avg_validation_losses.append(avg_loss)
        tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")

    return avg_training_losses, avg_validation_losses