feat: added simple check to distinguish other model types from transformer
parent 8311eabd4d
commit f97c7c9130
3 changed files with 62 additions and 57 deletions
@@ -18,8 +18,9 @@ def create_model(trial: tr.Trial, model: nn.Module):
             )
         case Transformer.__class__:
             nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
+            # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
             return model(
-                d_model=nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead),
+                d_model=128,  # hard coded for now as data loaders provide fixed (B, 128) tensors
                 nhead=nhead,
                 num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
                 num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
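The removed line tied d_model to nhead because nn.Transformer (via nn.MultiheadAttention) requires d_model to be divisible by nhead; sampling a multiplier keeps every trial valid. A minimal Optuna sketch of that pattern, where the objective function and study are illustrative and not part of this repo:

import optuna
import torch.nn as nn

def objective(trial: optuna.Trial) -> float:
    # Sample nhead first, then a multiplier, so d_model = nhead * k is always divisible by nhead.
    nhead = trial.suggest_categorical("nhead", [2, 4, 8])
    d_model = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
    model = nn.Transformer(
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
        num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
    )
    # A real objective would train `model` and return a validation metric; 0.0 is a placeholder.
    return 0.0

study = optuna.create_study()
study.optimize(objective, n_trials=5)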
@@ -27,11 +27,18 @@ def train(
     total_loss = []
 
     for x, y in tqdm(training_loader):
-        x = x.long().to(device)  # important for Embedding
-        y = y.long().to(device)  # must be (B,) for CE
+        # size (B, 128)
+        x = x.long().to(device)
+
+        # size (B)
+        y = y.long().to(device)
 
         optimizer.zero_grad()
-        logits = model(x)  # (B, 256)
+        if issubclass(type(model), nn.Transformer):
+            tgt = torch.cat([x[:, 1:], y.unsqueeze(1)], dim=1)
+            logits = model(x, tgt)
+        else:
+            logits = model(x)  # (B, 256)
         loss = loss_fn(logits, y)
         loss.backward()
         optimizer.step()
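The new branch only changes how transformer models are called: issubclass(type(model), nn.Transformer) routes them through the two-argument (src, tgt) forward, while every other model keeps the single-input call. A standalone sketch of that dispatch, assuming a bare nn.Transformer with float inputs of shape (B, S, d_model) rather than the repo's token tensors:

import torch
import torch.nn as nn

model = nn.Transformer(d_model=16, nhead=2, num_encoder_layers=2,
                       num_decoder_layers=2, batch_first=True)

x = torch.randn(4, 8, 16)  # (B, S, d_model) source batch
y = torch.randn(4, 16)     # (B, d_model) next-step target per sequence

if issubclass(type(model), nn.Transformer):
    # Shift the source left by one step and append the target as the last token,
    # mirroring tgt = torch.cat([x[:, 1:], y.unsqueeze(1)], dim=1) from the diff.
    tgt = torch.cat([x[:, 1:], y.unsqueeze(1)], dim=1)  # (B, S, d_model)
    out = model(x, tgt)                                 # (B, S, d_model)
else:
    out = model(x)

print(out.shape)  # torch.Size([4, 8, 16])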