feat: transformer

This commit is contained in:
Robin Meersman 2025-11-30 22:50:19 +01:00
parent d74558557c
commit c9758e1e40
3 changed files with 63 additions and 1 deletions

View file

@@ -1 +1,2 @@
from .cnn import CNNPredictor
from .transformer import Transformer

View file

@@ -0,0 +1 @@
from .transformer import Transformer

View file

@@ -0,0 +1,60 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor
class Transformer(nn.Transformer):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
layer_norm_eps=1e-05
):
super().__init__(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
batch_first=False,
norm_first=False,
device=None,
dtype=None
)
def forward(
self,
src: Tensor,
tgt: Tensor,
src_mask: Optional[Tensor] = None,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
src_is_causal: Optional[bool] = None,
tgt_is_causal: Optional[bool] = None,
memory_is_causal: bool = False,
) -> Tensor:
return super().forward(
src,
tgt,
src_mask,
tgt_mask,
memory_mask,
src_key_padding_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
src_is_causal,
tgt_is_causal,
memory_is_causal,
)