"""Optuna-driven hyperparameter search for the CNN predictor."""

from pathlib import Path
from typing import Callable

import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader

from .trainer import Trainer
from ..models.cnn import CNNPredictor
from .train import train

# Vocabulary size used for every model built in this module, so that the
# models evaluated during the search and the final rebuilt model agree.
DEFAULT_VOCAB_SIZE = 256

# Directory (relative to the working directory) where the best model is saved.
SAVED_MODELS_DIR = Path("saved_models")


def create_model(
    trial: tr.Trial, vocab_size: int = DEFAULT_VOCAB_SIZE
) -> CNNPredictor:
    """Build a ``CNNPredictor`` whose sizes are sampled from *trial*.

    Args:
        trial: Optuna trial used to sample the hyperparameters. Registers the
            parameters ``"hidden_dim"`` and ``"embed_dim"``, both log-uniform
            integers in ``[64, 512]``.
        vocab_size: Vocabulary size forwarded to the model.

    Returns:
        A new ``CNNPredictor`` configured with the sampled values.
    """
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
    embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
    return CNNPredictor(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        embed_dim=embedding_dim,
    )


def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: str,
) -> float:
    """Train one sampled model and return its lowest validation loss.

    Args:
        trial: Optuna trial that supplies the hyperparameters.
        training_loader: Loader passed to ``train`` as training data.
        validation_loader: Loader passed to ``train`` as validation data.
        loss_fn: Loss function passed to ``train``.
        device: Device the model is moved to before training.

    Returns:
        The minimum of the validation losses returned by ``train``
        (assumes ``train`` returns ``(train_losses, validation_losses)`` as
        sequences of numbers -- TODO confirm against ``train``).
    """
    model = create_model(trial).to(device)
    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
    return min(validation_loss)


class OptunaTrainer(Trainer):
    """Trainer that searches CNN hyperparameters with an Optuna study."""

    def __init__(self, n_trials: int | None = None):
        """Create the trainer.

        Args:
            n_trials: Number of Optuna trials to run; defaults to 20 when
                ``None``.
        """
        super().__init__()
        self.n_trials = n_trials if n_trials is not None else 20
        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str,
    ) -> None:
        """Run the study and save a model built from the best parameters.

        Args:
            model: Only used to name the output file after its class. When
                ``None``, the file is named after the rebuilt best model's
                class instead.
            train_loader: Training data for every trial.
            validation_loader: Validation data for every trial.
            loss_fn: Loss function for every trial.
            n_epochs: Currently unused; it is not forwarded to ``train``.
            device: Device each trial's model is moved to.
        """
        study = optuna.create_study(study_name="CNN network", direction="minimize")
        study.optimize(
            lambda trial: objective_function(
                trial, train_loader, validation_loader, loss_fn, device
            ),
            n_trials=self.n_trials,
        )

        best_params = study.best_trial.params
        # The study only records the sampled sizes, so vocab_size has to be
        # supplied here to match the models that were actually evaluated.
        best_model = CNNPredictor(vocab_size=DEFAULT_VOCAB_SIZE, **best_params)
        # NOTE(review): best_model is freshly initialised here and is saved
        # without being trained -- confirm whether it should be retrained
        # (or the best trial's weights kept) before saving.

        # Avoid producing "NoneType.pt" when no reference model is given.
        name_source = model if model is not None else best_model
        SAVED_MODELS_DIR.mkdir(parents=True, exist_ok=True)
        torch.save(
            best_model, SAVED_MODELS_DIR / f"{name_source.__class__.__name__}.pt"
        )