feat: fix transformer

Robin Meersman 2025-12-10 14:46:10 +01:00
parent f97c7c9130
commit d12bb25d0a
5 changed files with 65 additions and 44 deletions

(file 1 of 5)

@@ -1,9 +1,8 @@
 from .Model import Model
 from .cnn import CNNPredictor
-from .transformer import Transformer
+from .transformer import ByteTransformer
 
 model_called: dict[str, type[Model]] = {
     'cnn': CNNPredictor,
-    'transformer': Transformer
+    'transformer': ByteTransformer
 }
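
The registry maps a config key to a model class; the trainer resolves models by name through it. A minimal usage sketch (assuming the package imports as src.models; constructor defaults per the new ByteTransformer signature):

    from src.models import model_called

    # Resolve the class by its config key, then instantiate it with the
    # defaults from this commit (d_model=512, max_len=128, ...).
    model_cls = model_called['transformer']
    model = model_cls()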

(file 2 of 5)

@@ -1 +1 @@
-from .transformer import Transformer
+from .transformer import ByteTransformer

(file 3 of 5)

@@ -1,10 +1,23 @@
 from typing import Optional
 import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, arange
+from src.models import Model
 
 
-class Transformer(nn.Transformer):
+class LearnedPositionalEncoding(Model):
+    def __init__(self, max_len, d_model):
+        super().__init__()
+        self.pos_emb = nn.Embedding(max_len, d_model)
+
+    def forward(self, x):
+        # x: [seq, batch, d_model]
+        seq_len = x.size(0)
+        positions = arange(seq_len, device=x.device).unsqueeze(1)  # [seq, 1]
+        return x + self.pos_emb(positions)  # broadcast over batch
+
+
+class ByteTransformer(nn.Module):
     def __init__(
         self,
         d_model=512,
@@ -14,9 +27,17 @@ class Transformer(nn.Transformer):
         dim_feedforward=2048,
         dropout=0.1,
         activation="relu",
-        layer_norm_eps=1e-05
+        layer_norm_eps=1e-05,
+        max_len=128
     ):
-        super().__init__(
+        super().__init__()
+        self.src_embedding = nn.Embedding(256, d_model)
+        self.tgt_embedding = nn.Embedding(256, d_model)
+        self.src_pos = LearnedPositionalEncoding(max_len, d_model)
+        self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)
+        self.transformer = nn.Transformer(
             d_model=d_model,
             nhead=nhead,
             num_encoder_layers=num_encoder_layers,
@@ -28,34 +49,22 @@ class Transformer(nn.Transformer):
             batch_first=False,
             norm_first=False,
             device=None,
-            dtype=None
+            dtype=None,
         )
+        self.output_proj = nn.Linear(d_model, 256)
         self.loss_function = nn.CrossEntropyLoss()
 
     def forward(
         self,
         src: Tensor,
         tgt: Tensor,
-        src_mask: Optional[Tensor] = None,
-        tgt_mask: Optional[Tensor] = None,
-        memory_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-        tgt_key_padding_mask: Optional[Tensor] = None,
-        memory_key_padding_mask: Optional[Tensor] = None,
-        src_is_causal: Optional[bool] = None,
-        tgt_is_causal: Optional[bool] = None,
-        memory_is_causal: bool = False,
     ) -> Tensor:
-        return super().forward(
-            src,
-            tgt,
-            src_mask,
-            tgt_mask,
-            memory_mask,
-            src_key_padding_mask,
-            tgt_key_padding_mask,
-            memory_key_padding_mask,
-            src_is_causal,
-            tgt_is_causal,
-            memory_is_causal,
-        )
+        src_embeds = self.src_embedding(src)
+        tgt_embeds = self.tgt_embedding(tgt)
+
+        src_pos = self.src_pos(src_embeds)
+        tgt_pos = self.tgt_pos(tgt_embeds)
+
+        return self.output_proj(self.transformer(src_pos, tgt_pos))
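
As a quick sanity check of the two new classes, a standalone sketch (import path and tensor sizes are assumptions; it uses the seq-first [seq, batch] layout that batch_first=False implies):

    import torch
    from src.models import ByteTransformer

    model = ByteTransformer()              # all-default hyperparameters
    src = torch.randint(0, 256, (16, 4))   # [seq=16, batch=4] byte ids
    tgt = torch.randint(0, 256, (16, 4))
    logits = model(src, tgt)               # [16, 4, 256] per-position byte logits

Two things worth double-checking against the training code: _forward in train.py indexes tensors batch-first (x[:, :-1], logits[:, -1, :]), which clashes with batch_first=False here, and no causal tgt_mask is passed to the decoder, so target positions can attend to future bytes.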

(file 4 of 5)

@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
 from .train import train
 from .trainer import Trainer
-from ..models import Model, CNNPredictor, Transformer
+from ..models import Model, CNNPredictor, ByteTransformer
 
 
 def create_model(trial: tr.Trial, model: nn.Module):
@@ -16,7 +16,7 @@ def create_model(trial: tr.Trial, model: nn.Module):
                 embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
                 vocab_size=256,
             )
-        case Transformer.__class__:
+        case ByteTransformer.__class__:
             nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
             # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
             return model(
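
create_model is meant to be driven from an Optuna objective; a sketch of that loop (the objective body is a stand-in, and only the suggest_* calls mirror this diff). One caveat: the value pattern ByteTransformer.__class__ evaluates to the metaclass type, so the case arm compares model == type rather than testing for ByteTransformer itself; a guard such as case _ if model is ByteTransformer: may be closer to the intent.

    import optuna
    from src.models import ByteTransformer

    def objective(trial: optuna.trial.Trial) -> float:
        # create_model is the function shown above; training is elided and a
        # constant stands in for the real validation loss in this sketch.
        model = create_model(trial, ByteTransformer)
        return 0.0

    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=20)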

(file 5 of 5)

@@ -1,15 +1,31 @@
+from typing import Callable
+
 import torch
-import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
-from typing import Callable
+
+from ..models import ByteTransformer, Model
+
+
+def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
+    if isinstance(model, ByteTransformer):
+        tgt_in = torch.cat([
+            torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
+            x[:, :-1]
+        ], dim=1)
+        logits = model(x, tgt_in)
+        # only consider the last time step of the model where the full context
+        # is available
+        return logits[:, -1, :]
+    return model(x)
+
+
 def train(
-    model: nn.Module,
+    model: Model,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    loss_fn: Callable,
     epochs: int = 100,
     learning_rate: float = 1e-3,
     weight_decay: float = 1e-8,
@@ -34,11 +50,8 @@ def train(
             y = y.long().to(device)
 
             optimizer.zero_grad()
-            if issubclass(type(model), nn.Transformer):
-                tgt = torch.cat([x[:, 1:], y.unsqueeze(1)], dim=1)
-                logits = model(x, tgt)
-            else:
-                logits = model(x)  # (B, 256)
+            logits = _forward(model, x, device)
 
             loss = loss_fn(logits, y)
             loss.backward()
             optimizer.step()
@@ -55,7 +68,7 @@ def train(
             x = x.long().to(device)
             y = y.long().to(device)
 
-            logits = model(x)
+            logits = _forward(model, x, device)
 
             loss = loss_fn(logits, y)
             losses.append(loss.item())
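
To make the decoder-input shift in _forward concrete, a toy sketch (it assumes batch-first [batch, seq] byte windows, matching the x[:, :-1] indexing above; note that byte value 0 doubles as the start token):

    import torch

    x = torch.tensor([[65, 66, 67, 68]])           # one window of 4 bytes
    start = torch.zeros(x.shape[0], 1, dtype=torch.long)
    tgt_in = torch.cat([start, x[:, :-1]], dim=1)  # tensor([[ 0, 65, 66, 67]])

    # The model receives the full window as src and the right-shifted window
    # as tgt; logits[:, -1, :] then holds the prediction for the byte that
    # follows x, which loss_fn compares against y.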