feat: Add model choice

Tibo De Peuter 2025-12-06 21:52:31 +01:00
parent bb241154d9
commit ef50d6321e
Signed by: tdpeuter (GPG key ID: 38297DE43F75FFE2)
10 changed files with 102 additions and 54 deletions


@@ -1,26 +1,26 @@
 from typing import Callable
 import torch
-from torch import nn as nn
+from torch import nn
 from torch.utils.data import DataLoader
-from .trainer import Trainer
 from .train import train
+from .trainer import Trainer
+from ..models import Model
 from ..utils import print_losses

 class FullTrainer(Trainer):
     def execute(
         self,
-        model: nn.Module | None,
+        model: Model,
         train_loader: DataLoader,
         validation_loader: DataLoader,
-        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
         n_epochs: int,
         device: str
-    ) -> None:
+    ) -> nn.Module:
         if model is None:
             raise ValueError("Model must be provided: run optuna optimizations first")
         model.to(device)
-        train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
-        print_losses(train_loss, val_loss)
+        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs)
+        print_losses(train_loss, val_loss)
+        return model
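With this change FullTrainer.execute returns the trained network instead of None, and takes the loss from the model itself rather than from a loss_fn argument. A minimal usage sketch, assuming a hypothetical Model subclass that exposes a loss_function attribute (ToyModel and the loaders below are illustrative, not part of the commit):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModel(nn.Module):
        # Hypothetical model: the commit assumes every Model carries
        # its own loss_function for the trainers to use.
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 1)
            self.loss_function = nn.MSELoss()

        def forward(self, x):
            return self.linear(x)

    data = TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
    loader = DataLoader(data, batch_size=8)

    # execute(...) now hands the trained module back to the caller:
    # trained = FullTrainer().execute(ToyModel(), loader, loader, n_epochs=5, device="cpu")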


@@ -3,60 +3,73 @@ from typing import Callable
 import optuna
 import optuna.trial as tr
 import torch
-from torch import nn as nn
+from torch import nn
 from torch.utils.data import DataLoader
-from .trainer import Trainer
-from ..models.cnn import CNNPredictor
 from .train import train
+from .trainer import Trainer
+from ..models import Model, CNNPredictor, Transformer

-def create_model(trial: tr.Trial, vocab_size: int = 256):
-    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
-    embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
-    return CNNPredictor(
-        vocab_size=vocab_size,
-        hidden_dim=hidden_dim,
-        embed_dim=embedding_dim,
-    )
+def create_model(trial: tr.Trial, model: nn.Module):
+    match model.__class__:
+        case CNNPredictor.__class__:
+            return model(
+                hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
+                embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
+                vocab_size=256,
+            )
+        case Transformer.__class__:
+            nhead = trial.suggest_int("nhead", 2, 8, log=True)
+            d_model = trial.suggest_int("d_model", 64, 512, step=nhead)
+            return model(
+                d_model=d_model,
+                nhead=nhead,
+                num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
+                num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
+                dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
+                dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
+                activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
+                layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
+            )
+    return None

 def objective_function(
     trial: tr.Trial,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    model: Model,
     device: str
 ):
-    model = create_model(trial).to(device)
-    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
+    model = create_model(trial, model).to(device)
+    _, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
     return min(validation_loss)

 class OptunaTrainer(Trainer):
     def __init__(self, n_trials: int | None = None):
         super().__init__()
-        self.n_trials = n_trials if n_trials is not None else 20
+        self.n_trials = n_trials if n_trials else 20
         print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

     def execute(
         self,
-        model: nn.Module | None,
+        model: Model,
         train_loader: DataLoader,
         validation_loader: DataLoader,
-        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
         n_epochs: int,
         device: str
-    ) -> None:
-        study = optuna.create_study(study_name="CNN network", direction="minimize")
+    ) -> nn.Module:
+        study = optuna.create_study(direction="minimize")
         study.optimize(
-            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
+            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
             n_trials=self.n_trials
         )
         best_params = study.best_trial.params
-        best_model = CNNPredictor(
+        best_model = model(
             **best_params
         )
+        torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")
+        return best_model
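The per-model search space above hinges on dispatching by model class in create_model. For reference, a standalone sketch of that dispatch using plain class identity, trimmed to two of the suggested hyperparameters per branch (illustrative only, not the committed code; note that SomeClass.__class__ evaluates to type, so a value pattern like CNNPredictor.__class__ compares every model against type rather than against CNNPredictor):

    import optuna.trial as tr
    from ..models import CNNPredictor, Transformer  # same import as in the commit

    def create_model_sketch(trial: tr.Trial, model_cls: type):
        # Dispatch on the model class itself; `is` avoids the
        # `__class__`-of-a-class pitfall described above.
        if model_cls is CNNPredictor:
            return model_cls(
                hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
                embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
                vocab_size=256,
            )
        if model_cls is Transformer:
            nhead = trial.suggest_int("nhead", 2, 8, log=True)
            return model_cls(
                d_model=trial.suggest_int("d_model", 64, 512, step=nhead),
                nhead=nhead,
            )
        return None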


@@ -1,7 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable
-import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
@@ -15,8 +13,7 @@ class Trainer(ABC):
         model: nn.Module | None,
         train_loader: DataLoader,
         validation_loader: DataLoader,
-        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
         n_epochs: int,
         device: str
-    ) -> None:
-        pass
+    ) -> nn.Module:
+        pass
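Under the updated contract every Trainer implementation receives the model and must hand an nn.Module back. A minimal compliant subclass, purely illustrative:

    class NoOpTrainer(Trainer):
        # Illustrative only: performs no training, just honours the new
        # execute(...) -> nn.Module contract.
        def execute(self, model, train_loader, validation_loader, n_epochs, device):
            model.to(device)
            return model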