"""Optuna-driven hyperparameter search for the CNN predictor."""
from pathlib import Path
from typing import Callable

import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader

from .trainer import Trainer
from ..models.cnn import CNNPredictor
from .train import train

# Vocabulary size used by every trial; the final model must be rebuilt with
# the same value so it matches the architecture the search actually scored.
_DEFAULT_VOCAB_SIZE = 256


def create_model(trial: tr.Trial, vocab_size: int = _DEFAULT_VOCAB_SIZE):
    """Build a ``CNNPredictor`` whose layer sizes are sampled from *trial*.

    Two integer hyperparameters are sampled log-uniformly in ``[64, 512]``:
    ``"hidden_dim"`` and ``"embed_dim"``. The Optuna parameter names match the
    ``CNNPredictor`` keyword names, which is what allows
    ``study.best_trial.params`` to be unpacked straight into the constructor.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        vocab_size: Vocabulary size forwarded to ``CNNPredictor``.

    Returns:
        A freshly constructed ``CNNPredictor``.
    """
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
    embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
    return CNNPredictor(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        embed_dim=embedding_dim,
    )


def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: str,
):
    """Train one sampled model and return its best validation loss.

    Args:
        trial: Optuna trial that supplies the hyperparameters.
        training_loader: Loader passed to ``train`` as the training data.
        validation_loader: Loader passed to ``train`` as the validation data.
        loss_fn: Loss callable forwarded to ``train``.
        device: Device the sampled model is moved to before training.

    Returns:
        The minimum of the validation-loss collection returned by ``train``
        (presumably one entry per epoch -- confirm against ``train``).
    """
    model = create_model(trial).to(device)
    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
    return min(validation_loss)


class OptunaTrainer(Trainer):
    """Trainer that runs an Optuna study and saves the best configuration."""

    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str,
    ) -> None:
        """Search hyperparameters, retrain the best model and save it.

        Runs a 20-trial minimisation study, rebuilds a ``CNNPredictor`` from
        the best trial's parameters, trains it, and writes it to
        ``saved_models/<ClassName>.pt``.

        Args:
            model: Only used to name the output file (its class name). May be
                ``None``, in which case the best model's class name is used.
            train_loader: Training data for every trial and the final fit.
            validation_loader: Validation data for every trial and the final fit.
            loss_fn: Loss callable forwarded to ``train``.
            n_epochs: Currently unused by this trainer; kept so the signature
                stays compatible with callers.
            device: Device used for the trial models and the final fit.
        """
        study = optuna.create_study(study_name="CNN network", direction="minimize")
        study.optimize(
            lambda trial: objective_function(
                trial, train_loader, validation_loader, loss_fn, device
            ),
            n_trials=20,
        )

        best_params = study.best_trial.params
        # vocab_size is not an Optuna parameter, so it has to be supplied
        # explicitly to rebuild the same architecture the trials used.
        best_model = CNNPredictor(vocab_size=_DEFAULT_VOCAB_SIZE, **best_params).to(device)

        # A freshly constructed model only has random weights; fit it the same
        # way the trials were fitted so the checkpoint is actually usable.
        # NOTE(review): assumes ``train`` updates the model in place -- confirm.
        train(best_model, train_loader, validation_loader, loss_fn)

        # ``model`` may be None; fall back to the best model for the file name
        # instead of producing "NoneType.pt".
        named_after = model if model is not None else best_model

        output_dir = Path("saved_models")
        # torch.save does not create missing parent directories.
        output_dir.mkdir(parents=True, exist_ok=True)
        # Save from CPU so the checkpoint can be loaded on machines without
        # the training device.
        torch.save(best_model.cpu(), output_dir / f"{type(named_after).__name__}.pt")