From 59722acf76044c462a7955dcd4ad1cd900d1f131 Mon Sep 17 00:00:00 2001
From: Tibo De Peuter
Date: Sun, 7 Dec 2025 20:45:29 +0100
Subject: [PATCH] fix: Transformer Optuna params

Co-authored-by: Robin Meersman
---
 src/trainers/OptunaTrainer.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/trainers/OptunaTrainer.py b/src/trainers/OptunaTrainer.py
index e9a7417..de273e9 100644
--- a/src/trainers/OptunaTrainer.py
+++ b/src/trainers/OptunaTrainer.py
@@ -1,8 +1,5 @@
-from typing import Callable
-
 import optuna
 import optuna.trial as tr
-import torch
 from torch import nn
 from torch.utils.data import DataLoader
 
@@ -20,10 +17,9 @@ def create_model(trial: tr.Trial, model: nn.Module):
                 vocab_size=256,
             )
         case Transformer.__class__:
-            nhead = trial.suggest_int("nhead", 2, 8, log=True)
-            d_model = trial.suggest_int("d_model", 64, 512, step=nhead)
+            nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
             return model(
-                d_model=d_model,
+                d_model=nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead),
                 nhead=nhead,
                 num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
                 num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
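
For context on why the patch samples a multiplier instead of d_model directly:
nn.Transformer requires d_model to be divisible by nhead, and the old code
could violate that (e.g. nhead=3 gives d_model = 64 + 3k, which is never a
multiple of 3) while also giving "d_model" a different step per trial. Below
is a minimal, self-contained sketch of the new sampling scheme; the objective
body is a placeholder for illustration, not the repository's training loop.

    import optuna

    def objective(trial: optuna.trial.Trial) -> float:
        # Restrict nhead to powers of two, as in the patch.
        nhead = trial.suggest_categorical("nhead", [2, 4, 8])
        # Sample a multiplier rather than d_model itself, so every trial
        # yields d_model in [64, 512] with d_model % nhead == 0.
        d_model = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
        assert d_model % nhead == 0 and 64 <= d_model <= 512
        return float(d_model)  # placeholder objective

    study = optuna.create_study()
    study.optimize(objective, n_trials=20)

Note that the d_model_mult range still depends on nhead: Optuna's default TPE
sampler supports such dynamically constructed search spaces, whereas
GridSampler expects a fixed search space declared up front.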