feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -0,0 +1,63 @@
from typing import Callable
import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from trainer import Trainer
from ..model.cnn import CNNPredictor
from train import train
def create_model(trial: tr.Trial, vocab_size: int = 256):
num_layers = trial.suggest_int("num_layers", 1, 6)
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
kernel_size = trial.suggest_int("kernel_size", 2, 7)
dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
return CNNPredictor(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_dim=hidden_dim,
kernel_size=kernel_size,
dropout_prob=dropout_prob,
use_batchnorm=use_batchnorm
)
def objective_function(
trial: tr.Trial,
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
device: str
):
model = create_model(trial).to(device)
_, validation_loss = train(model, training_loader, validation_loader, loss_fn)
return min(validation_loss)
class OptunaTrainer(Trainer):
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
study = optuna.create_study(study_name="CNN network", direction="minimize")
study.optimize(
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
n_trials=20
)
best_params = study.best_trial.params
best_model = CNNPredictor(
**best_params
)
torch.save(best_model, "models/final_model.pt")