fix: Transformer Optuna params
Co-authored-by: Robin Meersman <echteenrobin@gmail.com>
This commit is contained in:
parent
ef50d6321e
commit
59722acf76
1 changed file with 2 additions and 6 deletions
|
|
@@ -1,8 +1,5 @@
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import optuna
|
import optuna
|
||||||
import optuna.trial as tr
|
import optuna.trial as tr
|
||||||
import torch
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
@@ -20,10 +17,9 @@ def create_model(trial: tr.Trial, model: nn.Module):
|
||||||
vocab_size=256,
|
vocab_size=256,
|
||||||
)
|
)
|
||||||
case Transformer.__class__:
|
case Transformer.__class__:
|
||||||
nhead = trial.suggest_int("nhead", 2, 8, log=True)
|
nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2
|
||||||
d_model = trial.suggest_int("d_model", 64, 512, step=nhead)
|
|
||||||
return model(
|
return model(
|
||||||
d_model=d_model,
|
d_model=nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead),
|
||||||
nhead=nhead,
|
nhead=nhead,
|
||||||
num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
|
num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
|
||||||
num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
|
num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
|
||||||
|
|
|
||||||
Reference in a new issue