Merge branch 'main' into portability

This commit is contained in:
Tibo De Peuter 2025-12-05 10:58:07 +01:00
commit 013afe832f
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
44 changed files with 5 additions and 2834 deletions

26
trainers/FullTrainer.py Normal file
View file

@ -0,0 +1,26 @@
from typing import Callable
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from .train import train
from ..utils import print_losses
class FullTrainer(Trainer):
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
if model is None:
raise ValueError("Model must be provided: run optuna optimizations first")
model.to(device)
train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
print_losses(train_loss, val_loss)

62
trainers/OptunaTrainer.py Normal file
View file

@ -0,0 +1,62 @@
from typing import Callable
import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from ..models.cnn import CNNPredictor
from .train import train
def create_model(trial: tr.Trial, vocab_size: int = 256):
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
return CNNPredictor(
vocab_size=vocab_size,
hidden_dim=hidden_dim,
embed_dim=embedding_dim,
)
def objective_function(
trial: tr.Trial,
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
device: str
):
model = create_model(trial).to(device)
_, validation_loss = train(model, training_loader, validation_loader, loss_fn)
return min(validation_loss)
class OptunaTrainer(Trainer):
def __init__(self, n_trials: int | None = None):
super().__init__()
self.n_trials = n_trials if n_trials is not None else 20
print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
study = optuna.create_study(study_name="CNN network", direction="minimize")
study.optimize(
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
n_trials=self.n_trials
)
best_params = study.best_trial.params
best_model = CNNPredictor(
**best_params
)
torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")

3
trainers/__init__.py Normal file
View file

@ -0,0 +1,3 @@
from .OptunaTrainer import OptunaTrainer
from .FullTrainer import FullTrainer
from .trainer import Trainer

59
trainers/train.py Normal file
View file

@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from typing import Callable
def train(
model: nn.Module,
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
epochs: int = 100,
learning_rate: float = 1e-3,
weight_decay: float = 1e-8,
device="cuda"
) -> tuple[list[float], list[float]]:
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
avg_training_losses = []
avg_validation_losses = []
for epoch in range(epochs):
model.train()
total_loss = []
for x, y in tqdm(training_loader):
x = x.long().to(device) # important for Embedding
y = y.long().to(device) # must be (B,) for CE
optimizer.zero_grad()
logits = model(x) # (B, 256)
loss = loss_fn(logits, y)
loss.backward()
optimizer.step()
total_loss.append(loss.item())
avg_training_losses.append(sum(total_loss) / len(total_loss))
# ----- validation -----
model.eval()
with torch.no_grad():
losses = []
for x, y in validation_loader:
x = x.long().to(device)
y = y.long().to(device)
logits = model(x)
loss = loss_fn(logits, y)
losses.append(loss.item())
avg_loss = sum(losses) / len(losses)
avg_validation_losses.append(avg_loss)
return avg_training_losses, avg_validation_losses

22
trainers/trainer.py Normal file
View file

@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Callable
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class Trainer(ABC):
"""Abstract class for trainers."""
@abstractmethod
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
pass