diff --git a/src/models/__init__.py b/src/models/__init__.py
index e7e29b4..e329dbc 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -1,9 +1,8 @@
 from .Model import Model
 from .cnn import CNNPredictor
-from .transformer import Transformer
-
+from .transformer import ByteTransformer
 
 model_called: dict[str, type[Model]] = {
     'cnn': CNNPredictor,
-    'transformer': Transformer
+    'transformer': ByteTransformer
 }
diff --git a/src/models/transformer/__init__.py b/src/models/transformer/__init__.py
index 6ff14d5..9817800 100644
--- a/src/models/transformer/__init__.py
+++ b/src/models/transformer/__init__.py
@@ -1 +1 @@
-from .transformer import Transformer
\ No newline at end of file
+from .transformer import ByteTransformer
\ No newline at end of file
diff --git a/src/models/transformer/transformer.py b/src/models/transformer/transformer.py
index 774ae7c..f85e60d 100644
--- a/src/models/transformer/transformer.py
+++ b/src/models/transformer/transformer.py
@@ -1,10 +1,23 @@
 from typing import Optional
 
 import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, arange
+
+from src.models import Model
 
 
-class Transformer(nn.Transformer):
+class LearnedPositionalEncoding(Model):
+    def __init__(self, max_len, d_model):
+        super().__init__()
+        self.pos_emb = nn.Embedding(max_len, d_model)
+
+    def forward(self, x):
+        # x: [seq, batch, d_model]
+        seq_len = x.size(0)
+        positions = arange(seq_len, device=x.device).unsqueeze(1) # [seq, 1]
+        return x + self.pos_emb(positions) # broadcast over batch
+
+class ByteTransformer(nn.Module):
     def __init__(
         self,
         d_model=512,
@@ -14,9 +27,17 @@ class Transformer(nn.Transformer):
         dim_feedforward=2048,
         dropout=0.1,
         activation="relu",
-        layer_norm_eps=1e-05
+        layer_norm_eps=1e-05,
+        max_len=128
     ):
-        super().__init__(
+        super().__init__()
+        self.src_embedding = nn.Embedding(256, d_model)
+        self.tgt_embedding = nn.Embedding(256, d_model)
+
+        self.src_pos = LearnedPositionalEncoding(max_len, d_model)
+        self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)
+
+        self.transformer = nn.Transformer(
             d_model=d_model,
             nhead=nhead,
             num_encoder_layers=num_encoder_layers,
@@ -28,34 +49,22 @@ class Transformer(nn.Transformer):
             batch_first=False,
             norm_first=False,
             device=None,
-            dtype=None
+            dtype=None,
         )
+
+        self.output_proj = nn.Linear(d_model, 256)
+        self.loss_function = nn.CrossEntropyLoss()
 
     def forward(
         self,
         src: Tensor,
         tgt: Tensor,
-        src_mask: Optional[Tensor] = None,
-        tgt_mask: Optional[Tensor] = None,
-        memory_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-        tgt_key_padding_mask: Optional[Tensor] = None,
-        memory_key_padding_mask: Optional[Tensor] = None,
-        src_is_causal: Optional[bool] = None,
-        tgt_is_causal: Optional[bool] = None,
-        memory_is_causal: bool = False,
     ) -> Tensor:
-        return super().forward(
-            src,
-            tgt,
-            src_mask,
-            tgt_mask,
-            memory_mask,
-            src_key_padding_mask,
-            tgt_key_padding_mask,
-            memory_key_padding_mask,
-            src_is_causal,
-            tgt_is_causal,
-            memory_is_causal,
-        )
+        src_embeds = self.src_embedding(src)
+        tgt_embeds = self.tgt_embedding(tgt)
+
+        src_pos = self.src_pos(src_embeds)
+        tgt_pos = self.tgt_pos(tgt_embeds)
+
+        return self.output_proj(self.transformer(src_pos, tgt_pos))
 
diff --git a/src/trainers/OptunaTrainer.py b/src/trainers/OptunaTrainer.py
index 8b8e602..e40aeeb 100644
--- a/src/trainers/OptunaTrainer.py
+++ b/src/trainers/OptunaTrainer.py
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
 
 from .train import train
 from .trainer import Trainer
-from ..models import Model, CNNPredictor, Transformer
+from ..models import Model, CNNPredictor, ByteTransformer
 
 
 def create_model(trial: tr.Trial, model: nn.Module):
@@ -16,7 +16,7 @@ def create_model(trial: tr.Trial, model: nn.Module):
                 embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
                 vocab_size=256,
             )
-        case Transformer.__class__:
+        case ByteTransformer.__class__:
            nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2
            # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
            return model(
diff --git a/src/trainers/train.py b/src/trainers/train.py
index 5fb9270..d26e7de 100644
--- a/src/trainers/train.py
+++ b/src/trainers/train.py
@@ -1,15 +1,31 @@
+from typing import Callable
+
 import torch
-import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
-from typing import Callable
+
+from ..models import ByteTransformer, Model
+
+
+def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
+    if isinstance(model, ByteTransformer):
+        tgt_in = torch.cat([
+            torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
+            x[:, :-1]
+        ], dim=1)
+        logits = model(x, tgt_in)
+
+        # only consider the last time step of the model where the full context
+        # is available
+        return logits[:, -1, :]
+    return model(x)
 
 
 def train(
-    model: nn.Module,
+    model: Model,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    loss_fn: Callable,
    epochs: int = 100,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-8,
@@ -17,7 +33,7 @@
 ) -> tuple[list[float], list[float]]:
     model.to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
-    
+
     avg_training_losses = []
     avg_validation_losses = []
 
@@ -34,11 +50,8 @@ def train(
             y = y.long().to(device)
 
             optimizer.zero_grad()
-            if issubclass(type(model), nn.Transformer):
-                tgt = torch.cat([x[:, 1:], y.unsqueeze(1)], dim=1)
-                logits = model(x, tgt)
-            else:
-                logits = model(x) # (B, 256)
+            logits = _forward(model, x, device)
+
             loss = loss_fn(logits, y)
             loss.backward()
             optimizer.step()
@@ -55,7 +68,7 @@ def train(
 
             x = x.long().to(device)
             y = y.long().to(device)
-            logits = model(x)
+            logits = _forward(model, x, device)
             loss = loss_fn(logits, y)
             losses.append(loss.item())
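
A minimal usage sketch of the new ByteTransformer forward path (not part of the patch). It assumes sequence-first [seq, batch] byte tensors, matching the module's batch_first=False setting and the positional-encoding comment, and all sizes and hyperparameters below are made up for illustration; train.py builds the shifted decoder input through _forward instead of calling the model directly.

import torch

from src.models import ByteTransformer

# hypothetical small hyperparameters; unspecified ones fall back to the
# defaults in the patched __init__
model = ByteTransformer(d_model=64, nhead=4, dim_feedforward=128, max_len=128)

seq_len, batch = 32, 4
src = torch.randint(0, 256, (seq_len, batch))                 # byte IDs, [seq, batch]
tgt_in = torch.cat([torch.zeros(1, batch, dtype=torch.long),  # decoder input shifted right
                    src[:-1]], dim=0)
logits = model(src, tgt_in)                                   # [seq, batch, 256] byte logits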