feat: transformer fixed
This commit is contained in:
parent
f97c7c9130
commit
d12bb25d0a
5 changed files with 65 additions and 44 deletions
|
|
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
|
|||
|
||||
from .train import train
|
||||
from .trainer import Trainer
|
||||
from ..models import Model, CNNPredictor, Transformer
|
||||
from ..models import Model, CNNPredictor, ByteTransformer
|
||||
|
||||
|
||||
def create_model(trial: tr.Trial, model: nn.Module):
|
||||
|
|
@@ -16,7 +16,7 @@ def create_model(trial: tr.Trial, model: nn.Module):
|
|||
embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
|
||||
vocab_size=256,
|
||||
)
|
||||
case Transformer.__class__:
|
||||
case ByteTransformer.__class__:
|
||||
nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2
|
||||
# d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
|
||||
return model(
|
||||
|
|
|
|||
Reference in a new issue