"""Optuna hyperparameter search for ``CNNPredictor`` and ``ByteTransformer`` models."""

import optuna
import optuna.trial as tr
from torch import nn
from torch.utils.data import DataLoader

from .train import train
from .trainer import Trainer
from ..models import Model, CNNPredictor, ByteTransformer


def create_model(trial: tr.BaseTrial, model: type[nn.Module] | nn.Module) -> nn.Module | None:
    """Instantiate ``model`` with hyperparameters sampled from ``trial``.

    Args:
        trial: Source of hyperparameters via the ``suggest_*`` API. A live
            ``optuna.trial.Trial`` samples new values; an
            ``optuna.trial.FixedTrial`` replays previously chosen ones.
        model: The model class to build (``CNNPredictor``, ``ByteTransformer``
            or a subclass of either). If an instance is passed, a new instance
            of its class is built.

    Returns:
        A freshly constructed model, or ``None`` when no search space is
        defined for the given class.
    """
    # Dispatch on the class object itself. Comparing ``__class__`` attributes
    # would compare metaclasses (``type`` for both) and match every model.
    model_cls = model if isinstance(model, type) else type(model)

    if issubclass(model_cls, CNNPredictor):
        return model_cls(
            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
            vocab_size=256,
        )

    if issubclass(model_cls, ByteTransformer):
        # Powers of two only, so every choice divides d_model=128 evenly.
        nhead = trial.suggest_categorical("nhead", [2, 4, 8])
        # TODO: search d_model as a multiple of nhead once the data loaders no
        # longer fix the feature width at 128.
        return model_cls(
            d_model=128,  # hard coded for now as data loaders provide fixed (B, 128) tensors
            nhead=nhead,
            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
        )

    return None


def _build_model(trial: tr.BaseTrial, model: type[nn.Module] | nn.Module) -> nn.Module:
    """Call ``create_model`` and raise a clear error for unsupported classes."""
    instance = create_model(trial, model)
    if instance is None:
        raise ValueError(f"No Optuna search space is defined for model {model!r}")
    return instance


def objective_function(
    trial: tr.Trial, training_loader: DataLoader, validation_loader: DataLoader, model: type[Model], device: str
):
    """Optuna objective: build a model for ``trial``, train it, return its best validation loss.

    Args:
        trial: The Optuna trial supplying hyperparameters.
        training_loader: Data used by ``train`` for optimisation.
        validation_loader: Data used by ``train`` to compute validation losses.
        model: The model class to instantiate (see ``create_model``).
        device: Device string the model is moved to before training.

    Returns:
        The minimum of the validation losses reported by ``train``.

    Raises:
        ValueError: If ``model`` has no search space in ``create_model``.
    """
    instance = _build_model(trial, model).to(device)
    _, validation_loss = train(instance, training_loader, validation_loader, instance.loss_function, device=device)
    return min(validation_loss)


class OptunaTrainer(Trainer):
    """Trainer that runs an Optuna study to pick model hyperparameters."""

    def __init__(self, n_trials: int | None = None):
        """Create the trainer; a missing or zero ``n_trials`` falls back to 20."""
        super().__init__()
        self.n_trials = n_trials if n_trials else 20
        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

    def execute(
        self,
        model: type[Model],
        train_loader: DataLoader,
        validation_loader: DataLoader,
        n_epochs: int,
        device: str,
    ) -> nn.Module:
        """Search hyperparameters for ``model`` and return an instance built with the best ones.

        Each trial trains a candidate via ``objective_function``. The study
        minimises that function's return value.

        NOTE(review): ``n_epochs`` is not used here. The returned model is
        freshly constructed from the best trial's parameters (not the trained
        instance from that trial) and is not moved to ``device`` — confirm
        whether callers expect a trained model.

        Raises:
            ValueError: If ``model`` has no search space in ``create_model``.
        """
        study = optuna.create_study(direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
            n_trials=self.n_trials,
        )
        # Replay the winning parameters through create_model so the fixed
        # constructor arguments (vocab_size, d_model) are supplied as well.
        # Those are not sampled and so are absent from best_trial.params.
        return _build_model(tr.FixedTrial(study.best_trial.params), model)