This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to it's state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../trainers/OptunaTrainer.py
Robin Meersman 73d1742cbd code cleanup
2025-11-30 19:21:29 +01:00

57 lines
1.7 KiB
Python

from typing import Callable
import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from ..models.cnn import CNNPredictor
from .train import train
def create_model(trial: tr.Trial, vocab_size: int = 256):
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
return CNNPredictor(
vocab_size=vocab_size,
hidden_dim=hidden_dim,
embed_dim=embedding_dim,
)
def objective_function(
trial: tr.Trial,
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
device: str
):
model = create_model(trial).to(device)
_, validation_loss = train(model, training_loader, validation_loader, loss_fn)
return min(validation_loss)
class OptunaTrainer(Trainer):
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
study = optuna.create_study(study_name="CNN network", direction="minimize")
study.optimize(
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
n_trials=20
)
best_params = study.best_trial.params
best_model = CNNPredictor(
**best_params
)
torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")