feat: transformer fixed
parent f97c7c9130
commit d12bb25d0a
5 changed files with 65 additions and 44 deletions
@@ -1,9 +1,8 @@
 from .Model import Model
 from .cnn import CNNPredictor
-from .transformer import Transformer
-
+from .transformer import ByteTransformer
 
 model_called: dict[str, type[Model]] = {
     'cnn': CNNPredictor,
-    'transformer': Transformer
+    'transformer': ByteTransformer
 }
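With both entries renamed, the model_called registry still maps a plain string to a model class; a minimal usage sketch (the lookup code below is illustrative, not part of this commit, and it assumes the constructors work with their default arguments):

from src.models import model_called

def build_model(name: str):
    # hypothetical helper: look the class up by name and instantiate it
    model_cls = model_called[name]   # e.g. 'transformer' -> ByteTransformer
    return model_cls()

model = build_model('transformer')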
@@ -1 +1 @@
-from .transformer import Transformer
+from .transformer import ByteTransformer
@@ -1,10 +1,23 @@
 from typing import Optional
 
 import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, arange
 
+from src.models import Model
+
 
-class Transformer(nn.Transformer):
+class LearnedPositionalEncoding(Model):
+    def __init__(self, max_len, d_model):
+        super().__init__()
+        self.pos_emb = nn.Embedding(max_len, d_model)
+
+    def forward(self, x):
+        # x: [seq, batch, d_model]
+        seq_len = x.size(0)
+        positions = arange(seq_len, device=x.device).unsqueeze(1)  # [seq, 1]
+        return x + self.pos_emb(positions)  # broadcast over batch
+
+class ByteTransformer(nn.Module):
     def __init__(
         self,
         d_model=512,
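LearnedPositionalEncoding adds one trainable embedding vector per position index to its input; a small sketch of its effect, with the class above in scope, assuming Model behaves like a plain nn.Module and the input is laid out as [seq, batch, d_model]:

import torch

pe = LearnedPositionalEncoding(max_len=128, d_model=512)
x = torch.zeros(10, 2, 512)   # [seq, batch, d_model]
out = pe(x)                   # same shape; since x is all zeros, row i holds the embedding for position i
print(out.shape)              # torch.Size([10, 2, 512])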
@@ -14,9 +27,17 @@ class Transformer(nn.Transformer):
         dim_feedforward=2048,
         dropout=0.1,
         activation="relu",
-        layer_norm_eps=1e-05
+        layer_norm_eps=1e-05,
+        max_len=128
     ):
-        super().__init__(
+        super().__init__()
+        self.src_embedding = nn.Embedding(256, d_model)
+        self.tgt_embedding = nn.Embedding(256, d_model)
+
+        self.src_pos = LearnedPositionalEncoding(max_len, d_model)
+        self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)
+
+        self.transformer = nn.Transformer(
             d_model=d_model,
             nhead=nhead,
             num_encoder_layers=num_encoder_layers,
@@ -28,34 +49,22 @@ class Transformer(nn.Transformer):
             batch_first=False,
             norm_first=False,
             device=None,
-            dtype=None
+            dtype=None,
         )
 
+        self.output_proj = nn.Linear(d_model, 256)
+
         self.loss_function = nn.CrossEntropyLoss()
 
     def forward(
         self,
         src: Tensor,
         tgt: Tensor,
-        src_mask: Optional[Tensor] = None,
-        tgt_mask: Optional[Tensor] = None,
-        memory_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-        tgt_key_padding_mask: Optional[Tensor] = None,
-        memory_key_padding_mask: Optional[Tensor] = None,
-        src_is_causal: Optional[bool] = None,
-        tgt_is_causal: Optional[bool] = None,
-        memory_is_causal: bool = False,
     ) -> Tensor:
-        return super().forward(
-            src,
-            tgt,
-            src_mask,
-            tgt_mask,
-            memory_mask,
-            src_key_padding_mask,
-            tgt_key_padding_mask,
-            memory_key_padding_mask,
-            src_is_causal,
-            tgt_is_causal,
-            memory_is_causal,
-        )
+        src_embeds = self.src_embedding(src)
+        tgt_embeds = self.tgt_embedding(tgt)
+
+        src_pos = self.src_pos(src_embeds)
+        tgt_pos = self.tgt_pos(tgt_embeds)
+
+        return self.output_proj(self.transformer(src_pos, tgt_pos))
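End to end, ByteTransformer now embeds byte ids itself instead of subclassing nn.Transformer; a minimal usage sketch, assuming the package imports as src.models as in the hunks above, that the remaining constructor arguments (nhead, num_encoder_layers, ...) have defaults, and that inputs are shaped [seq_len, batch] to match batch_first=False (shapes are illustrative):

import torch
from src.models import ByteTransformer

model = ByteTransformer()                 # assumed defaults: d_model=512, max_len=128, ...
src = torch.randint(0, 256, (16, 2))      # [src_seq, batch] byte ids
tgt = torch.randint(0, 256, (16, 2))      # [tgt_seq, batch] byte ids, e.g. shifted targets
logits = model(src, tgt)                  # [tgt_seq, batch, 256] scores over byte values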
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
 
 from .train import train
 from .trainer import Trainer
-from ..models import Model, CNNPredictor, Transformer
+from ..models import Model, CNNPredictor, ByteTransformer
 
 
 def create_model(trial: tr.Trial, model: nn.Module):
@@ -16,7 +16,7 @@ def create_model(trial: tr.Trial, model: nn.Module):
                 embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
                 vocab_size=256,
             )
-        case Transformer.__class__:
+        case ByteTransformer.__class__:
             nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
             # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
             return model(
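Since nn.Transformer requires d_model to be divisible by nhead, the commented-out line above hints at sampling a multiplier instead of d_model directly; a sketch of that idea as a standalone Optuna helper (the function name is hypothetical, not part of this commit):

import optuna

def suggest_transformer_dims(trial: optuna.Trial) -> tuple[int, int]:
    # nhead must divide d_model, so sample a multiplier rather than d_model itself
    nhead = trial.suggest_categorical("nhead", [2, 4, 8])
    d_model = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
    return nhead, d_model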
@@ -1,15 +1,31 @@
+from typing import Callable
+
 import torch
-import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
-from typing import Callable
+
+from ..models import ByteTransformer, Model
+
+
+def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
+    if isinstance(model, ByteTransformer):
+        tgt_in = torch.cat([
+            torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
+            x[:, :-1]
+        ], dim=1)
+        logits = model(x, tgt_in)
+
+        # only consider the last time step of the model where the full context
+        # is available
+        return logits[:, -1, :]
+    return model(x)
 
 
 def train(
-    model: nn.Module,
+    model: Model,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    loss_fn: Callable,
     epochs: int = 100,
     learning_rate: float = 1e-3,
     weight_decay: float = 1e-8,
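For intuition, the decoder input that _forward builds for ByteTransformer is the batch shifted right by one position behind a zero start token; a tiny standalone sketch with illustrative values:

import torch

x = torch.tensor([[10, 20, 30, 40]])                 # one sequence of byte ids, shape [1, 4]
tgt_in = torch.cat([
    torch.zeros(x.shape[0], 1, dtype=torch.long),    # start-of-sequence byte 0
    x[:, :-1]                                        # drop the last byte
], dim=1)
print(tgt_in)                                        # tensor([[ 0, 10, 20, 30]])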
@@ -34,11 +50,8 @@ def train(
             y = y.long().to(device)
 
             optimizer.zero_grad()
-            if issubclass(type(model), nn.Transformer):
-                tgt = torch.cat([x[:, 1:], y.unsqueeze(1)], dim=1)
-                logits = model(x, tgt)
-            else:
-                logits = model(x)  # (B, 256)
+            logits = _forward(model, x, device)
             loss = loss_fn(logits, y)
             loss.backward()
             optimizer.step()
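With the model-specific branching moved into _forward, the training step always hands loss_fn classification-shaped tensors: logits of shape [batch, 256] against a [batch] vector of next-byte targets. A minimal shape check with illustrative values:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(8, 256)           # [batch, 256] scores over byte values
y = torch.randint(0, 256, (8,))        # [batch] target byte ids
loss = loss_fn(logits, y)              # scalar tensor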
@@ -55,7 +68,7 @@ def train(
             x = x.long().to(device)
             y = y.long().to(device)
 
-            logits = model(x)
+            logits = _forward(model, x, device)
             loss = loss_fn(logits, y)
             losses.append(loss.item())