feat: Add model choice

This commit is contained in:
Tibo De Peuter 2025-12-06 21:52:31 +01:00
parent bb241154d9
commit ef50d6321e
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
10 changed files with 102 additions and 54 deletions

View file

@ -3,60 +3,73 @@ from typing import Callable
import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch import nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from ..models.cnn import CNNPredictor
from .train import train
from .trainer import Trainer
from ..models import Model, CNNPredictor, Transformer
def create_model(trial: tr.Trial, vocab_size: int = 256):
    """Build a ``CNNPredictor`` whose layer widths are sampled from ``trial``.

    Args:
        trial: Optuna trial used to sample ``hidden_dim`` and ``embed_dim``
            (both log-uniform integers in ``[64, 512]``).
        vocab_size: Passed through unchanged to ``CNNPredictor``.

    Returns:
        The newly constructed ``CNNPredictor``.
    """
    # Dict entries are evaluated in order, so "hidden_dim" is still suggested
    # before "embed_dim".
    sampled = {
        "hidden_dim": trial.suggest_int("hidden_dim", 64, 512, log=True),
        "embed_dim": trial.suggest_int("embed_dim", 64, 512, log=True),
    }
    return CNNPredictor(vocab_size=vocab_size, **sampled)
def create_model(
    trial: tr.Trial,
    model: type[nn.Module] | nn.Module,
) -> nn.Module:
    """Instantiate ``model`` with hyperparameters sampled from ``trial``.

    Args:
        trial: Optuna trial that supplies every sampled hyperparameter.
        model: The model class to build (``CNNPredictor``, ``Transformer`` or
            a subclass of either). An instance is tolerated as well, in which
            case its class is used.

    Returns:
        A newly constructed model of the requested class.

    Raises:
        ValueError: If ``model`` is neither a ``CNNPredictor`` nor a
            ``Transformer`` (sub)class.
    """
    # The model is *called* below to construct it, so callers pass the class
    # itself; fall back to the type when handed an instance.
    model_cls = model if isinstance(model, type) else type(model)

    # Dispatch with issubclass: comparing `model.__class__` against
    # `CNNPredictor.__class__` compares metaclasses, which are equal for
    # every ordinary class, so the first branch would always win.
    if issubclass(model_cls, CNNPredictor):
        return model_cls(
            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
            vocab_size=256,
        )

    if issubclass(model_cls, Transformer):
        nhead = trial.suggest_int("nhead", 2, 8, log=True)
        # Multi-head attention splits d_model evenly across heads (the keyword
        # names here mirror torch.nn.Transformer, which requires this), so
        # snap both bounds to multiples of nhead before stepping by nhead.
        lowest = -(-64 // nhead) * nhead  # ceil(64 / nhead) * nhead
        highest = (512 // nhead) * nhead
        d_model = trial.suggest_int("d_model", lowest, highest, step=nhead)
        return model_cls(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
        )

    # Fail loudly instead of returning None, which only surfaced later as an
    # AttributeError on `.to(device)`.
    raise ValueError(f"Unsupported model for hyperparameter search: {model_cls!r}")
def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    model: Model,
    device: str,
):
    """Optuna objective: build, train and score one candidate model.

    Args:
        trial: Optuna trial used by ``create_model`` to sample
            hyperparameters.
        training_loader: Data loader handed to ``train`` for training.
        validation_loader: Data loader handed to ``train`` for validation.
        model: The model to build; forwarded unchanged to ``create_model``.
        device: Device the freshly built model is moved to via ``.to``.

    Returns:
        The smallest entry of the validation losses returned by ``train``.
    """
    # Keep the built candidate in its own name rather than rebinding the
    # `model` parameter, so the requested model and the instance stay distinct.
    candidate = create_model(trial, model).to(device)

    # The loss comes from the model itself; there is no separate loss_fn
    # argument any more.
    _, validation_loss = train(
        candidate,
        training_loader,
        validation_loader,
        candidate.loss_function,
    )
    return min(validation_loss)
class OptunaTrainer(Trainer):
    """Trainer that picks hyperparameters by running an Optuna study."""

    def __init__(self, n_trials: int | None = None):
        """Create the trainer.

        Args:
            n_trials: Number of Optuna trials to run. ``None`` (or ``0``)
                falls back to 20.
        """
        super().__init__()
        self.n_trials = n_trials if n_trials else 20
        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

    def execute(
        self,
        model: Model,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str,
    ) -> nn.Module:
        """Run the study, rebuild the best configuration and save it.

        Args:
            model: The model to tune; forwarded to ``objective_function`` and
                ``create_model``.
            train_loader: Training data for every trial.
            validation_loader: Validation data for every trial.
            loss_fn: Not used here; each trial takes the loss from the model's
                ``loss_function``.
            n_epochs: Not used here.
            device: Device each trial's model is moved to.

        Returns:
            A model constructed with the best trial's hyperparameters. It is
            freshly constructed and not retrained here, so the best trial's
            trained weights are not carried over.
        """
        import os  # local import: only needed for the output directory

        study = optuna.create_study(direction="minimize")
        study.optimize(
            lambda trial: objective_function(
                trial, train_loader, validation_loader, model, device
            ),
            n_trials=self.n_trials,
        )

        # Replay the winning parameters through create_model so fixed
        # constructor arguments that are not part of the search space
        # (e.g. vocab_size) are supplied exactly as during the trials.
        best_params = study.best_trial.params
        best_model = create_model(tr.FixedTrial(best_params), model)

        # `model` is normally the class itself; `model.__class__.__name__`
        # would then be the metaclass name ("type"), so every model would
        # be written to the same file.
        model_cls = model if isinstance(model, type) else type(model)
        os.makedirs("saved_models", exist_ok=True)
        torch.save(best_model, f"saved_models/{model_cls.__name__}.pt")
        return best_model