import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Callable


def train(
    model: nn.Module,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epochs: int = 100,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-8,
    device: str = "cuda",
) -> tuple[list[float], list[float]]:
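    """Run a standard training/validation loop, returning per-epoch average losses."""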
    model.to(device)
    # AdamW applies decoupled weight decay alongside the Adam update
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    avg_training_losses = []
    avg_validation_losses = []

    for epoch in range(epochs):
        # ----- training -----
        model.train()
        training_losses = []

        for x, y in tqdm(training_loader, desc=f"epoch {epoch + 1}/{epochs}"):
            x = x.long().to(device)  # integer dtype required by nn.Embedding
            y = y.long().to(device)  # class indices, shape (B,), for cross-entropy

            optimizer.zero_grad()
            logits = model(x)  # (B, 256)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()

            training_losses.append(loss.item())

        avg_training_losses.append(sum(training_losses) / len(training_losses))

        # ----- validation -----
        model.eval()
        with torch.no_grad():  # no gradients needed for evaluation
            losses = []
            for x, y in validation_loader:
                x = x.long().to(device)
                y = y.long().to(device)

                logits = model(x)
                loss = loss_fn(logits, y)
                losses.append(loss.item())

            avg_loss = sum(losses) / len(losses)
            avg_validation_losses.append(avg_loss)

    return avg_training_losses, avg_validation_losses
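

# --- Usage sketch (an addition, not part of the original file) ---
# A minimal smoke test under assumed shapes: the `# (B, 256)` comment in the loop
# suggests 256-way classification over integer token inputs, so the toy model and
# random data below are illustrative stand-ins, not the author's actual setup.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    seq_len, n_classes = 16, 256  # assumed from the (B, 256) logits comment above
    xs = torch.randint(0, n_classes, (512, seq_len))  # toy integer token sequences
    ys = torch.randint(0, n_classes, (512,))          # one class index per sequence

    toy_model = nn.Sequential(
        nn.Embedding(n_classes, 32),
        nn.Flatten(),
        nn.Linear(seq_len * 32, n_classes),
    )
    train_loader = DataLoader(TensorDataset(xs[:448], ys[:448]), batch_size=64, shuffle=True)
    val_loader = DataLoader(TensorDataset(xs[448:], ys[448:]), batch_size=64)

    train_losses, val_losses = train(
        toy_model,
        train_loader,
        val_loader,
        loss_fn=nn.CrossEntropyLoss(),
        epochs=2,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    print(f"train: {train_losses}\nval:   {val_losses}")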