From c9758e1e406ece067550d88d2d462b25483ce12d Mon Sep 17 00:00:00 2001
From: Robin Meersman
Date: Sun, 30 Nov 2025 22:50:19 +0100
Subject: [PATCH] feat: transformer

Add a Transformer model that subclasses torch.nn.Transformer with the
standard defaults (d_model=512, nhead=8, six encoder/decoder layers,
post-norm, sequence-first inputs) and re-export it from the models
package.

---
 models/__init__.py                |  3 +-
 models/transformer/__init__.py    |  1 +
 models/transformer/transformer.py | 63 ++++++++++++++++++++++++++++++
 3 files changed, 66 insertions(+), 1 deletion(-)
 create mode 100644 models/transformer/__init__.py
 create mode 100644 models/transformer/transformer.py

diff --git a/models/__init__.py b/models/__init__.py
index 153551d..42e6d4c 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1 +1,2 @@
-from .cnn import CNNPredictor
\ No newline at end of file
+from .cnn import CNNPredictor
+from .transformer import Transformer
\ No newline at end of file
diff --git a/models/transformer/__init__.py b/models/transformer/__init__.py
new file mode 100644
index 0000000..6ff14d5
--- /dev/null
+++ b/models/transformer/__init__.py
@@ -0,0 +1 @@
+from .transformer import Transformer
\ No newline at end of file
diff --git a/models/transformer/transformer.py b/models/transformer/transformer.py
new file mode 100644
index 0000000..63032eb
--- /dev/null
+++ b/models/transformer/transformer.py
@@ -0,0 +1,63 @@
+from typing import Optional
+
+import torch.nn as nn
+from torch import Tensor
+
+
+class Transformer(nn.Transformer):
+    """Thin subclass of torch.nn.Transformer that pins the library defaults."""
+
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        num_decoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        layer_norm_eps=1e-05
+    ):
+        super().__init__(
+            d_model=d_model,
+            nhead=nhead,
+            num_encoder_layers=num_encoder_layers,
+            num_decoder_layers=num_decoder_layers,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            activation=activation,
+            layer_norm_eps=layer_norm_eps,
+            batch_first=False,  # inputs/outputs are (seq_len, batch, d_model)
+            norm_first=False,   # post-norm, the original layer ordering
+            device=None,
+            dtype=None
+        )
+
+    def forward(
+        self,
+        src: Tensor,
+        tgt: Tensor,
+        src_mask: Optional[Tensor] = None,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        src_is_causal: Optional[bool] = None,
+        tgt_is_causal: Optional[bool] = None,
+        memory_is_causal: bool = False,
+    ) -> Tensor:
+        # Pure delegation; the override only makes the signature explicit here.
+        return super().forward(
+            src,
+            tgt,
+            src_mask,
+            tgt_mask,
+            memory_mask,
+            src_key_padding_mask,
+            tgt_key_padding_mask,
+            memory_key_padding_mask,
+            src_is_causal,
+            tgt_is_causal,
+            memory_is_causal,
+        )
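
Reviewer note: a minimal smoke test for the new module might look like the
sketch below; it is an illustration, not part of the patch. It assumes the
repository root is on the import path so `from models import Transformer`
resolves. Tensor shapes follow from batch_first=False, i.e.
(seq_len, batch, d_model), and generate_square_subsequent_mask is the static
causal-mask helper inherited from nn.Transformer.

# Hypothetical smoke test; illustrates the wrapper, not part of this patch.
import torch
import torch.nn as nn

from models import Transformer  # assumes the repo root is importable

model = Transformer(d_model=512, nhead=8)
model.eval()  # disable dropout for a deterministic shape check

# batch_first=False, so tensors are laid out (seq_len, batch, d_model).
src = torch.randn(10, 2, 512)  # 10 source steps, batch of 2
tgt = torch.randn(7, 2, 512)   # 7 target steps, batch of 2

# Inherited static helper: upper-triangular -inf mask for causal decoding.
tgt_mask = nn.Transformer.generate_square_subsequent_mask(7)

out = model(src, tgt, tgt_mask=tgt_mask)
assert out.shape == (7, 2, 512)  # decoder output mirrors the target shape

Since the forward override only delegates, configuring nn.Transformer
directly would behave identically; keeping the subclass does make the
accepted arguments visible at the call site.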