This repository was archived on 2025-12-23. You can view its files and clone it, but you cannot change its state, for example by pushing commits or by creating new issues, pull requests, or comments.
2025ML-project-neural_compr.../trainers/train.py
Robin Meersman 73d1742cbd code cleanup
2025-11-30 19:21:29 +01:00

59 lines
1.7 KiB
Python

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from typing import Callable
def _mean(values: list[float]) -> float:
    """Return the arithmetic mean of *values*, or NaN if *values* is empty."""
    # An empty loader yields no batches; NaN signals "no data" instead of
    # crashing with ZeroDivisionError.
    return sum(values) / len(values) if values else float("nan")


def _train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    optimizer: torch.optim.Optimizer,
    device,
    show_progress: bool,
) -> float:
    """Run one optimisation pass over *loader*; return the mean per-batch loss."""
    model.train()
    batches = tqdm(loader) if show_progress else loader
    batch_losses = []
    for x, y in batches:
        # Both tensors are cast to int64. The original notes this is needed
        # for an Embedding input and for cross-entropy class-index targets
        # (presumably y is (B,) -- TODO confirm against the dataset).
        x = x.long().to(device)
        y = y.long().to(device)
        optimizer.zero_grad()
        logits = model(x)  # assumed (B, num_classes), e.g. (B, 256) -- confirm
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return _mean(batch_losses)


def _evaluate(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device,
) -> float:
    """Return the mean per-batch loss of *model* on *loader* without training."""
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for x, y in loader:
            x = x.long().to(device)
            y = y.long().to(device)
            batch_losses.append(loss_fn(model(x), y).item())
    return _mean(batch_losses)


def train(
    model: nn.Module,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epochs: int = 100,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-8,
    device="cuda",
    show_progress: bool = True,
) -> tuple[list[float], list[float]]:
    """Train *model* with AdamW and record per-epoch average losses.

    Each epoch runs one optimisation pass over ``training_loader``. It then
    runs one gradient-free evaluation pass over ``validation_loader``.
    Inputs and targets are cast to ``long`` before being moved to *device*.

    Args:
        model: Module to train; it is moved to *device* in place.
        training_loader: Yields ``(x, y)`` batches used for optimisation.
        validation_loader: Yields ``(x, y)`` batches used for evaluation only.
        loss_fn: Called as ``loss_fn(logits, y)`` and must return a scalar tensor.
        epochs: Number of train/validate cycles.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        device: Target device for the model and every batch.
        show_progress: Wrap the training loader in a tqdm progress bar.

    Returns:
        ``(avg_training_losses, avg_validation_losses)``. Each list has one
        entry per epoch: the mean of the per-batch loss values. An entry is
        NaN when the corresponding loader produced no batches that epoch.
        Note that batches are weighted equally regardless of their size.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    avg_training_losses: list[float] = []
    avg_validation_losses: list[float] = []
    for _ in range(epochs):
        avg_training_losses.append(
            _train_one_epoch(
                model, training_loader, loss_fn, optimizer, device, show_progress
            )
        )
        avg_validation_losses.append(
            _evaluate(model, validation_loader, loss_fn, device)
        )
    return avg_training_losses, avg_validation_losses