feat: replace Transformer wrapper with ByteTransformer (byte embeddings + learned positional encoding)
parent f97c7c9130
commit d12bb25d0a

5 changed files with 65 additions and 44 deletions
@@ -1,9 +1,8 @@
 from .Model import Model
 from .cnn import CNNPredictor
-from .transformer import Transformer
-
+from .transformer import ByteTransformer

 model_called: dict[str, type[Model]] = {
     'cnn': CNNPredictor,
-    'transformer': Transformer
+    'transformer': ByteTransformer
 }
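Note: model_called maps config-facing names to model classes. A minimal lookup sketch, assuming this file is src/models/__init__.py and that ByteTransformer is constructible with its defaults; the 'name' variable and the call site are illustrative, not part of this commit:

    from src.models import model_called  # import path assumed

    name = 'transformer'            # hypothetical config value
    model = model_called[name]()    # instantiates ByteTransformer with its defaults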
@@ -1 +1 @@
-from .transformer import Transformer
+from .transformer import ByteTransformer
@@ -1,10 +1,23 @@
 from typing import Optional

 import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, arange

 from src.models import Model


-class Transformer(nn.Transformer):
+class LearnedPositionalEncoding(Model):
+    def __init__(self, max_len, d_model):
+        super().__init__()
+        self.pos_emb = nn.Embedding(max_len, d_model)
+
+    def forward(self, x):
+        # x: [seq, batch, d_model]
+        seq_len = x.size(0)
+        positions = arange(seq_len, device=x.device).unsqueeze(1)  # [seq, 1]
+        return x + self.pos_emb(positions)  # broadcast over batch
+
+class ByteTransformer(nn.Module):
     def __init__(
         self,
         d_model=512,
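Note on LearnedPositionalEncoding above: positions has shape [seq, 1], so pos_emb(positions) is [seq, 1, d_model] and broadcasts across the batch dimension of x ([seq, batch, d_model]); sequences longer than max_len would raise an index error in nn.Embedding. A standalone shape check, using only torch and nothing project-specific:

    import torch
    import torch.nn as nn

    seq_len, batch, d_model, max_len = 16, 4, 512, 128
    x = torch.zeros(seq_len, batch, d_model)        # [seq, batch, d_model]
    pos_emb = nn.Embedding(max_len, d_model)

    positions = torch.arange(seq_len).unsqueeze(1)  # [seq, 1]
    out = x + pos_emb(positions)                    # [seq, 1, d_model] broadcasts over batch
    assert out.shape == (seq_len, batch, d_model)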
@@ -14,9 +27,17 @@ class Transformer(nn.Transformer):
         dim_feedforward=2048,
         dropout=0.1,
         activation="relu",
-        layer_norm_eps=1e-05
+        layer_norm_eps=1e-05,
+        max_len=128
     ):
-        super().__init__(
+        super().__init__()
+        self.src_embedding = nn.Embedding(256, d_model)
+        self.tgt_embedding = nn.Embedding(256, d_model)
+
+        self.src_pos = LearnedPositionalEncoding(max_len, d_model)
+        self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)
+
+        self.transformer = nn.Transformer(
             d_model=d_model,
             nhead=nhead,
             num_encoder_layers=num_encoder_layers,
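Note: ByteTransformer now composes an inner nn.Transformer instead of subclassing it, with two 256-entry embeddings (one id per byte value). A quick structural check, assuming the import path below and that the constructor arguments not visible in these hunks (nhead, layer counts) also have defaults:

    import torch.nn as nn
    from src.models.transformer import ByteTransformer  # import path assumed

    model = ByteTransformer()
    assert isinstance(model.transformer, nn.Transformer)  # wrapped, not inherited
    assert model.src_embedding.num_embeddings == 256      # one embedding per byte value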
@@ -28,34 +49,22 @@ class Transformer(nn.Transformer):
             batch_first=False,
             norm_first=False,
             device=None,
-            dtype=None
+            dtype=None,
         )
+
+        self.output_proj = nn.Linear(d_model, 256)
+
         self.loss_function = nn.CrossEntropyLoss()

     def forward(
         self,
         src: Tensor,
         tgt: Tensor,
-        src_mask: Optional[Tensor] = None,
-        tgt_mask: Optional[Tensor] = None,
-        memory_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-        tgt_key_padding_mask: Optional[Tensor] = None,
-        memory_key_padding_mask: Optional[Tensor] = None,
-        src_is_causal: Optional[bool] = None,
-        tgt_is_causal: Optional[bool] = None,
-        memory_is_causal: bool = False,
     ) -> Tensor:
-        return super().forward(
-            src,
-            tgt,
-            src_mask,
-            tgt_mask,
-            memory_mask,
-            src_key_padding_mask,
-            tgt_key_padding_mask,
-            memory_key_padding_mask,
-            src_is_causal,
-            tgt_is_causal,
-            memory_is_causal,
-        )
+        src_embeds = self.src_embedding(src)
+        tgt_embeds = self.tgt_embedding(tgt)
+
+        src_pos = self.src_pos(src_embeds)
+        tgt_pos = self.tgt_pos(tgt_embeds)
+
+        return self.output_proj(self.transformer(src_pos, tgt_pos))
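Note on the new forward: the mask and padding arguments are removed and the inner transformer is called without masks, so decoding is not causal by default. A minimal usage sketch under the defaults shown here; shapes are [seq, batch] because batch_first=False, and the import path and training-loop details are assumptions, not part of this commit:

    import torch
    from src.models.transformer import ByteTransformer  # import path assumed

    model = ByteTransformer()                 # d_model=512, max_len=128 defaults

    src = torch.randint(0, 256, (32, 4))      # byte ids, [src_seq, batch]
    tgt = torch.randint(0, 256, (32, 4))      # byte ids, [tgt_seq, batch]

    logits = model(src, tgt)                  # [tgt_seq, batch, 256]

    targets = torch.randint(0, 256, (32, 4))  # next-byte targets, [tgt_seq, batch]
    loss = model.loss_function(logits.reshape(-1, 256), targets.reshape(-1))

    # If causal decoding is needed, a [tgt_seq, tgt_seq] mask could be built with
    # nn.Transformer.generate_square_subsequent_mask(tgt.size(0)) and passed to
    # model.transformer(...) directly, since model.forward no longer accepts masks.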