feat: transformer fixed
parent f97c7c9130
commit d12bb25d0a
5 changed files with 65 additions and 44 deletions
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
-from .train import train
+from .trainer import Trainer
-from ..models import Model, CNNPredictor, Transformer
+from ..models import Model, CNNPredictor, ByteTransformer


 def create_model(trial: tr.Trial, model: nn.Module):
@@ -16,7 +16,7 @@ def create_model(trial: tr.Trial, model: nn.Module):
                 embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
                 vocab_size=256,
             )
-        case Transformer.__class__:
+        case ByteTransformer.__class__:
             nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
             # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
             return model(
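For context, a minimal sketch of how a create_model factory in this style might dispatch on the model class with Optuna. The optuna import alias (tr), the string-based match, and the ByteTransformer constructor arguments are assumptions for illustration; only the suggest_* calls come from the hunk above.

    import optuna.trial as tr  # assumed alias; the diff only shows the tr.Trial annotation
    import torch.nn as nn

    def create_model_sketch(trial: tr.Trial, model: type[nn.Module]) -> nn.Module:
        # Hypothetical dispatch on the class name; constructor arguments are illustrative.
        match model.__name__:
            case "CNNPredictor":
                return model(
                    embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
                    vocab_size=256,  # byte-level vocabulary
                )
            case "ByteTransformer":
                nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # only powers of 2
                return model(
                    nhead=nhead,
                    # keep d_model divisible by nhead by sampling a multiplier
                    d_model=nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead),
                    vocab_size=256,
                )
            case _:
                raise ValueError(f"unsupported model class: {model.__name__}")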
@@ -1,15 +1,31 @@
-from typing import Callable
 import torch
 import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
+from typing import Callable
+
+from ..models import ByteTransformer, Model
+
+
+def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
+    if isinstance(model, ByteTransformer):
+        tgt_in = torch.cat([
+            torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
+            x[:, :-1]
+        ], dim=1)
+        logits = model(x, tgt_in)
+
+        # only consider the last time step of the model where the full context
+        # is available
+        return logits[:, -1, :]
+    return model(x)


 def train(
-    model: nn.Module,
+    model: Model,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    loss_fn: Callable,
     epochs: int = 100,
     learning_rate: float = 1e-3,
     weight_decay: float = 1e-8,
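The new _forward helper feeds the ByteTransformer a decoder input that is the source sequence shifted right by one position, with a zero byte standing in for a start token, and keeps only the last time step's logits so both model types return a (B, 256) next-byte distribution. A small standalone sketch of that shift; the shapes and the zero start token mirror the hunk above, while the random tensors are stand-ins:

    import torch

    x = torch.randint(0, 256, (2, 5), dtype=torch.long)  # (B, T) batch of byte sequences

    # Shift right: prepend a zero "start" token and drop the last input byte,
    # so position t of tgt_in only contains bytes strictly before position t of x.
    tgt_in = torch.cat([
        torch.zeros(x.shape[0], 1, dtype=torch.long),
        x[:, :-1],
    ], dim=1)

    assert tgt_in.shape == x.shape
    assert torch.equal(tgt_in[:, 1:], x[:, :-1])

    # Stand-in for model(x, tgt_in): logits over the 256 byte values at every step.
    logits = torch.randn(x.shape[0], x.shape[1], 256)
    last_step = logits[:, -1, :]  # (B, 256): only the step that has seen the full context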
@@ -17,7 +33,7 @@ def train(
 ) -> tuple[list[float], list[float]]:
     model.to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

     avg_training_losses = []
     avg_validation_losses = []
@@ -34,11 +50,8 @@ def train(
             y = y.long().to(device)

             optimizer.zero_grad()
-            if issubclass(type(model), nn.Transformer):
-                tgt = torch.cat([x[:, 1:], y.unsqueeze(1)], dim=1)
-                logits = model(x, tgt)
-            else:
-                logits = model(x)  # (B, 256)
+            logits = _forward(model, x, device)

             loss = loss_fn(logits, y)
             loss.backward()
             optimizer.step()
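With the helper in place the training step no longer branches on nn.Transformer; both model types produce (B, 256) logits. A sketch of the resulting step, assuming the loss is nn.CrossEntropyLoss over the 256 byte classes and that y holds one target byte per sequence; neither assumption is shown in this commit:

    import torch
    import torch.nn as nn

    loss_fn = nn.CrossEntropyLoss()  # assumed; any callable mapping (logits, target) to a scalar works

    # Dummy tensors standing in for _forward(model, x, device) and the dataloader batch:
    logits = torch.randn(8, 256, requires_grad=True)  # (B, 256) next-byte logits
    y = torch.randint(0, 256, (8,))                   # (B,) target bytes

    loss = loss_fn(logits, y)
    loss.backward()  # with real parameters this populates gradients before optimizer.step()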
@@ -55,7 +68,7 @@ def train(
             x = x.long().to(device)
             y = y.long().to(device)

-            logits = model(x)
+            logits = _forward(model, x, device)
             loss = loss_fn(logits, y)
             losses.append(loss.item())
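A possible call site for the updated train signature, which returns the per-epoch average training and validation losses. The dataset construction and the loss choice are assumptions for illustration; the model is left out because its constructor is not part of this diff:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Hypothetical byte-level dataset: 5-byte contexts predicting the next byte.
    xs = torch.randint(0, 256, (1024, 5))
    ys = torch.randint(0, 256, (1024,))
    training_loader = DataLoader(TensorDataset(xs, ys), batch_size=64, shuffle=True)
    validation_loader = DataLoader(TensorDataset(xs[:128], ys[:128]), batch_size=64)

    # avg_training_losses, avg_validation_losses = train(
    #     model,                       # CNNPredictor or ByteTransformer instance
    #     training_loader,
    #     validation_loader,
    #     loss_fn=nn.CrossEntropyLoss(),
    #     epochs=10,
    #     learning_rate=1e-3,
    #     weight_decay=1e-8,
    # )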