fix: use ByteTransformer in model factory (rename from Transformer)

This commit is contained in:
Robin Meersman 2025-12-10 14:46:10 +01:00
parent f97c7c9130
commit d12bb25d0a
5 changed files with 65 additions and 44 deletions

View file

@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
from .train import train
from .trainer import Trainer
from ..models import Model, CNNPredictor, Transformer
from ..models import Model, CNNPredictor, ByteTransformer
def create_model(trial: tr.Trial, model: nn.Module):
@ -16,7 +16,7 @@ def create_model(trial: tr.Trial, model: nn.Module):
embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
vocab_size=256,
)
case Transformer.__class__:
case ByteTransformer.__class__:
nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2
# d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
return model(