feat: added simple check to distinguish other model types from transformer

This commit is contained in:
Robin Meersman 2025-12-09 22:24:14 +01:00
parent 8311eabd4d
commit f97c7c9130
3 changed files with 62 additions and 57 deletions

View file

@@ -18,8 +18,9 @@ def create_model(trial: tr.Trial, model: nn.Module):
)
case Transformer.__class__:
nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2
# d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
return model(
d_model=nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead),
d_model=128, # hard coded for now as data loaders provide fixed (B, 128) tensors
nhead=nhead,
num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),