feat: added a simple check to distinguish other model types from Transformer

Robin Meersman 2025-12-09 22:24:14 +01:00
parent 8311eabd4d
commit f97c7c9130
3 changed files with 62 additions and 57 deletions


@@ -18,8 +18,9 @@ def create_model(trial: tr.Trial, model: nn.Module):
             )
         case Transformer.__class__:
             nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2
+            # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
             return model(
-                d_model=nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead),
+                d_model=128, # hard coded for now as data loaders provide fixed (B, 128) tensors
                 nhead=nhead,
                 num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
                 num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
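
For reference, the now commented-out d_model_mult line encodes a useful constraint: sampling d_model as nhead * d_model_mult keeps d_model divisible by nhead, which nn.Transformer's multi-head attention requires, while bounding it to roughly [64, 512]. A minimal sketch of that sampling using Optuna's standard API (the sample_d_model helper and standalone study below are illustrative, not repository code):

import optuna

def sample_d_model(trial: optuna.trial.Trial) -> int:
    # Sample nhead first, then a multiplier, so d_model = nhead * mult is
    # always divisible by nhead and stays roughly within [64, 512].
    nhead = trial.suggest_categorical("nhead", [2, 4, 8])
    return nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)

study = optuna.create_study()
print(sample_d_model(study.ask()))  # e.g. 256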


@@ -27,11 +27,18 @@ def train(
     total_loss = []
     for x, y in tqdm(training_loader):
-        x = x.long().to(device) # important for Embedding
-        y = y.long().to(device) # must be (B,) for CE
+        # size (B, 128)
+        x = x.long().to(device)
+        # size (B)
+        y = y.long().to(device)
         optimizer.zero_grad()
-        logits = model(x) # (B, 256)
+        if issubclass(type(model), nn.Transformer):
+            tgt = torch.cat([x[:, 1:], y.unsqueeze(1)], dim=1)
+            logits = model(x, tgt)
+        else:
+            logits = model(x) # (B, 256)
         loss = loss_fn(logits, y)
         loss.backward()
         optimizer.step()
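
For context, the tgt built in the Transformer branch is the input sequence shifted left by one position with the label appended (a teacher-forcing style decoder target), and issubclass(type(model), nn.Transformer) behaves the same as isinstance(model, nn.Transformer) here. A minimal sketch of the shape bookkeeping (batch size, vocabulary size, and the random tensors are made up for illustration):

import torch

B, L, V = 4, 128, 256            # illustrative batch size, sequence length, vocab size
x = torch.randint(0, V, (B, L))  # stand-in for a training batch, shape (B, 128)
y = torch.randint(0, V, (B,))    # stand-in for the labels, shape (B,)

# Drop the first token and append the label: the decoder target is the input
# shifted left by one position, ending on the label to be predicted.
tgt = torch.cat([x[:, 1:], y.unsqueeze(1)], dim=1)
print(tgt.shape)                 # torch.Size([4, 128])

Note that a bare nn.Transformer expects already-embedded float tensors for src and tgt, so the model(x, tgt) call above presumably relies on a subclass that embeds the integer tokens before running the encoder and decoder.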