Merge branch 'main' into portability

This commit is contained in:
Tibo De Peuter 2025-12-06 20:24:13 +01:00
commit bb241154d9
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
3 changed files with 63 additions and 1 deletions

View file

@ -1 +1,2 @@
from .cnn import CNNPredictor
from .cnn import CNNPredictor
from .transformer import Transformer

View file

@ -0,0 +1 @@
from .transformer import Transformer

View file

@ -0,0 +1,60 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor
class Transformer(nn.Transformer):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
layer_norm_eps=1e-05
):
super().__init__(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
batch_first=False,
norm_first=False,
device=None,
dtype=None
)
def forward(
self,
src: Tensor,
tgt: Tensor,
src_mask: Optional[Tensor] = None,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
src_is_causal: Optional[bool] = None,
tgt_is_causal: Optional[bool] = None,
memory_is_causal: bool = False,
) -> Tensor:
return super().forward(
src,
tgt,
src_mask,
tgt_mask,
memory_mask,
src_key_padding_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
src_is_causal,
tgt_is_causal,
memory_is_causal,
)