This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../src/trainers/OptunaTrainer.py
2025-12-10 14:46:10 +01:00

72 lines
2.6 KiB
Python

import optuna
import optuna.trial as tr
from torch import nn
from torch.utils.data import DataLoader
from .train import train
from .trainer import Trainer
from ..models import Model, CNNPredictor, ByteTransformer
def create_model(trial: tr.Trial, model: nn.Module):
    """Instantiate *model* with hyperparameters sampled from *trial*.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        model: The model *class* to instantiate (it is called as a
            constructor below, despite the ``nn.Module`` annotation).

    Returns:
        A freshly constructed model, or ``None`` if the class is not one of
        the supported search spaces.
    """
    # BUG FIX: the original matched `model.__class__` against
    # `CNNPredictor.__class__` / `ByteTransformer.__class__`. Both sides of
    # that comparison are the metaclass (`type`), so the first case matched
    # for *every* model class and ByteTransformer was unreachable. Compare
    # the classes themselves instead.
    if model is CNNPredictor:
        return model(
            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
            vocab_size=256,
        )
    if model is ByteTransformer:
        nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
        # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
        return model(
            d_model=128,  # hard coded for now as data loaders provide fixed (B, 128) tensors
            nhead=nhead,
            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
        )
    return None
def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    model: Model,
    device: str
) -> float:
    """Run one Optuna trial: build a candidate model and return its best
    validation loss (the quantity the study minimizes).

    Args:
        trial: Optuna trial supplying the sampled hyperparameters.
        training_loader: Loader for the training split.
        validation_loader: Loader for the validation split.
        model: Model class handed through to :func:`create_model`.
        device: Torch device string (e.g. ``"cpu"`` / ``"cuda"``).

    Raises:
        ValueError: If *model* is not a class supported by ``create_model``.
    """
    candidate = create_model(trial, model)
    if candidate is None:
        # BUG FIX: create_model returns None for unsupported classes; fail
        # with a clear message instead of an opaque AttributeError on `.to`.
        raise ValueError(f"Unsupported model class for Optuna search: {model!r}")
    candidate = candidate.to(device)
    _, validation_loss = train(candidate, training_loader, validation_loader, candidate.loss_function, device=device)
    # validation_loss is presumably one value per epoch — TODO confirm against
    # `train`; the best epoch's loss is used as the trial objective.
    return min(validation_loss)
class OptunaTrainer(Trainer):
    """Trainer that performs an Optuna hyperparameter search and returns the
    model rebuilt with the best trial's hyperparameters.
    """

    def __init__(self, n_trials: int | None = None):
        super().__init__()
        # Only treat None as "use the default"; the original truthiness check
        # silently mapped an explicit n_trials=0 to 20.
        self.n_trials = n_trials if n_trials is not None else 20
        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

    def execute(
        self,
        model: Model,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        n_epochs: int,  # unused here; kept for the Trainer interface
        device: str
    ) -> nn.Module:
        """Run the study and return the model built from the best trial.

        Note: the returned model is freshly constructed (untrained); callers
        are expected to train it afterwards if trained weights are needed.
        """
        study = optuna.create_study(direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
            n_trials=self.n_trials
        )
        # BUG FIX: the original did `model(**best_params)`, which drops the
        # fixed constructor arguments (vocab_size=256 / d_model=128) that
        # create_model supplies, so reconstruction failed or mismatched.
        # Replay the best hyperparameters through create_model instead.
        best_model = create_model(tr.FixedTrial(study.best_trial.params), model)
        return best_model.to(device)