fix: Transformer Optuna params

Co-authored-by: Robin Meersman <echteenrobin@gmail.com>
Author: Tibo De Peuter
Date: 2025-12-07 20:45:29 +01:00
Commit: 59722acf76
Parent: ef50d6321e
Signed by: tdpeuter (GPG key ID: 38297DE43F75FFE2)


@@ -1,8 +1,5 @@
-from typing import Callable
 import optuna
 import optuna.trial as tr
-import torch
 from torch import nn
 from torch.utils.data import DataLoader
@@ -20,10 +17,9 @@ def create_model(trial: tr.Trial, model: nn.Module):
                 vocab_size=256,
             )
         case Transformer.__class__:
-            nhead = trial.suggest_int("nhead", 2, 8, log=True)
-            d_model = trial.suggest_int("d_model", 64, 512, step=nhead)
+            nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
             return model(
-                d_model=d_model,
+                d_model=nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead),
                 nhead=nhead,
                 num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
                 num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
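
Context for the change: sampling nhead and d_model independently lets Optuna propose combinations where d_model is not divisible by nhead, which torch.nn.Transformer rejects. Sampling a multiplier instead guarantees d_model % nhead == 0 by construction. A minimal standalone sketch of the same pattern follows; the objective function and its returned score are hypothetical placeholders, and only the two suggest_* calls mirror the commit:

import optuna

def objective(trial: optuna.trial.Trial) -> float:
    # nhead is drawn from powers of two, matching the commit.
    nhead = trial.suggest_categorical("nhead", [2, 4, 8])
    # Building d_model as nhead * multiplier keeps it an exact multiple
    # of nhead, so the nn.Transformer constraint always holds.
    d_model = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
    assert d_model % nhead == 0
    # Placeholder score so the study runs; a real objective would build,
    # train, and evaluate the model here.
    return float(d_model)

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=10)

Note that the bounds of d_model_mult depend on the sampled nhead, so the search space differs between trials; Optuna's define-by-run API permits such dynamically constructed spaces.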