code cleanup

This commit is contained in:
Robin Meersman 2025-11-30 19:21:29 +01:00
parent ea9cf12db0
commit 73d1742cbd
44 changed files with 6 additions and 2835 deletions

26
trainers/FullTrainer.py Normal file
View file

@@ -0,0 +1,26 @@
from typing import Callable
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from .train import train
from ..utils import print_losses
class FullTrainer(Trainer):
    """Trainer that runs the full training loop on an already-chosen model."""

    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str
    ) -> None:
        """Train *model* for *n_epochs* on *device* and print the loss curves.

        Raises:
            ValueError: if *model* is None (hyper-parameter search must run first).
        """
        if model is None:
            raise ValueError("Model must be provided: run optuna optimizations first")
        model.to(device)
        # BUG FIX: forward `device` explicitly — train() would otherwise default
        # to "cuda", ignoring the device the caller requested here.
        train_loss, val_loss = train(
            model, train_loader, validation_loader, loss_fn, n_epochs, device=device
        )
        print_losses(train_loss, val_loss)

57
trainers/OptunaTrainer.py Normal file
View file

@@ -0,0 +1,57 @@
from typing import Callable
import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from ..models.cnn import CNNPredictor
from .train import train
def create_model(trial: tr.Trial, vocab_size: int = 256):
    """Build a CNNPredictor whose layer widths are sampled from the Optuna trial.

    Both widths are drawn log-uniformly from [64, 512].
    """
    sampled_hidden = trial.suggest_int("hidden_dim", 64, 512, log=True)
    sampled_embed = trial.suggest_int("embed_dim", 64, 512, log=True)
    return CNNPredictor(vocab_size=vocab_size, hidden_dim=sampled_hidden, embed_dim=sampled_embed)
def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: str
):
    """Optuna objective: train a freshly sampled model, return its best validation loss.

    Returns:
        The minimum per-epoch average validation loss (the value Optuna minimizes).
    """
    model = create_model(trial).to(device)
    # BUG FIX: forward `device` so train() does not silently fall back to its
    # "cuda" default while the model was moved to the requested device.
    _, validation_loss = train(model, training_loader, validation_loader, loss_fn, device=device)
    return min(validation_loss)
class OptunaTrainer(Trainer):
    """Trainer that searches hyper-parameters with Optuna and saves the best model."""

    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str
    ) -> None:
        """Run 20 Optuna trials, retrain the winning architecture, and save it.

        The incoming *model* is ignored (it may be None for this trainer);
        the persisted model is rebuilt from the best trial's parameters.
        """
        study = optuna.create_study(study_name="CNN network", direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
            n_trials=20
        )
        best_params = study.best_trial.params  # {"hidden_dim": ..., "embed_dim": ...}
        best_model = CNNPredictor(**best_params)
        # BUG FIX: the model constructed above has random weights — retrain it
        # before persisting. This also uses `n_epochs`, which was previously
        # accepted but silently ignored.
        best_model.to(device)
        train(best_model, train_loader, validation_loader, loss_fn, n_epochs, device=device)
        # BUG FIX: name the file after the model actually being saved; `model`
        # may be None here, which previously produced "NoneType.pt".
        torch.save(best_model, f"saved_models/{best_model.__class__.__name__}.pt")

3
trainers/__init__.py Normal file
View file

@@ -0,0 +1,3 @@
from .OptunaTrainer import OptunaTrainer
from .FullTrainer import FullTrainer
from .trainer import Trainer

59
trainers/train.py Normal file
View file

@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from typing import Callable
def train(
    model: nn.Module,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epochs: int = 100,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-8,
    device="cuda"
) -> tuple[list[float], list[float]]:
    """Train *model* with AdamW, tracking per-epoch average losses.

    Args:
        model: the network to train (moved to *device* in place).
        training_loader: batches of (x, y) used for gradient updates.
        validation_loader: batches of (x, y) used for evaluation only.
        loss_fn: maps (logits, targets) to a scalar loss.
        epochs: number of passes over the training data.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        device: device string for model and batches.

    Returns:
        (avg_training_losses, avg_validation_losses), one entry per epoch.

    Raises:
        ValueError: if either loader yields no batches (previously this
            surfaced as an opaque ZeroDivisionError).
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    avg_training_losses: list[float] = []
    avg_validation_losses: list[float] = []
    for epoch in range(epochs):
        # ----- training -----
        model.train()
        batch_losses: list[float] = []
        for x, y in tqdm(training_loader):
            x = x.long().to(device)  # Embedding layers require integer indices
            y = y.long().to(device)  # targets must be class indices of shape (B,) for CE
            optimizer.zero_grad()
            logits = model(x)  # (B, 256)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        if not batch_losses:
            raise ValueError("training_loader yielded no batches")
        avg_training_losses.append(sum(batch_losses) / len(batch_losses))

        # ----- validation -----
        model.eval()
        with torch.no_grad():
            val_losses: list[float] = []
            for x, y in validation_loader:
                x = x.long().to(device)
                y = y.long().to(device)
                logits = model(x)
                loss = loss_fn(logits, y)
                val_losses.append(loss.item())
            if not val_losses:
                raise ValueError("validation_loader yielded no batches")
            avg_validation_losses.append(sum(val_losses) / len(val_losses))

    return avg_training_losses, avg_validation_losses

22
trainers/trainer.py Normal file
View file

@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Callable
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class Trainer(ABC):
"""Abstract class for trainers."""
@abstractmethod
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
pass