62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
from typing import Callable
|
|
|
|
import optuna
|
|
import optuna.trial as tr
|
|
import torch
|
|
from torch import nn as nn
|
|
from torch.utils.data import DataLoader
|
|
|
|
from .trainer import Trainer
|
|
from ..models.cnn import CNNPredictor
|
|
from .train import train
|
|
|
|
|
|
def create_model(trial: tr.Trial, vocab_size: int = 256):
    """Build a ``CNNPredictor`` whose layer sizes are sampled from *trial*.

    Two integer hyperparameters are requested from the trial, each on a log
    scale within [64, 512]: ``hidden_dim`` and ``embed_dim``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        vocab_size: Forwarded unchanged to ``CNNPredictor``.

    Returns:
        A newly constructed ``CNNPredictor`` using the sampled sizes.
    """
    # Dict literals evaluate in source order, so "hidden_dim" is still
    # suggested before "embed_dim".
    sampled_sizes = {
        "hidden_dim": trial.suggest_int("hidden_dim", 64, 512, log=True),
        "embed_dim": trial.suggest_int("embed_dim", 64, 512, log=True),
    }
    return CNNPredictor(vocab_size=vocab_size, **sampled_sizes)
|
|
|
|
|
|
def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: str
):
    """Score one Optuna trial by its best validation loss.

    A model is created from the trial via ``create_model``, moved to
    *device*, and handed to ``train`` together with both loaders and the
    loss function.

    Args:
        trial: Optuna trial supplying the hyperparameters.
        training_loader: Loader passed to ``train`` as the training data.
        validation_loader: Loader passed to ``train`` as the validation data.
        loss_fn: Loss callable passed to ``train``.
        device: Target passed to the model's ``.to()``.

    Returns:
        The minimum of the second value returned by ``train``
        (presumably a sequence of validation losses -- confirm against
        ``train``).
    """
    candidate = create_model(trial)
    candidate = candidate.to(device)

    # train() returns a pair; only the second element is used here.
    _ignored, val_losses = train(candidate, training_loader, validation_loader, loss_fn)

    lowest_loss = min(val_losses)
    return lowest_loss
|
|
|
|
|
|
class OptunaTrainer(Trainer):
    """Trainer that runs an Optuna search over the ``create_model`` space."""

    def __init__(self, n_trials: int | None = None):
        """Create the trainer.

        Args:
            n_trials: Number of Optuna trials to run; 20 when ``None``.
        """
        super().__init__()
        self.n_trials = n_trials if n_trials is not None else 20
        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str
    ) -> None:
        """Run the hyperparameter search and save a model built from the best trial.

        Args:
            model: Only used to name the output file after its class. May be
                ``None``, in which case the class of the rebuilt model is used.
            train_loader: Training data handed to every trial.
            validation_loader: Validation data handed to every trial.
            loss_fn: Loss callable handed to every trial.
            n_epochs: Currently unused by this trainer.
            device: Device each trial's model is moved to.

        Side effects:
            Writes ``saved_models/<ClassName>.pt`` relative to the current
            working directory, creating ``saved_models`` if needed.
        """
        import os

        study = optuna.create_study(study_name="CNN network", direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
            n_trials=self.n_trials,
        )

        # Rebuild through create_model with a FixedTrial so the architecture
        # matches what was evaluated during the search. Building
        # CNNPredictor(**best_params) directly would omit the vocab_size
        # argument that create_model passes for every trial.
        best_params = study.best_trial.params
        best_model = create_model(tr.FixedTrial(best_params))

        # NOTE(review): best_model is freshly constructed here, so the saved
        # file holds untrained weights. Retrain with the best parameters
        # before saving if a usable checkpoint is expected -- confirm intent.

        # `model` may be None; fall back to the rebuilt model so the file is
        # never named "NoneType.pt".
        name_source = model if model is not None else best_model

        # torch.save does not create missing parent directories.
        os.makedirs("saved_models", exist_ok=True)
        torch.save(best_model, f"saved_models/{name_source.__class__.__name__}.pt")
|