Merge branch 'main' into portability
This commit is contained in:
commit
bb241154d9
3 changed files with 63 additions and 1 deletions
|
|
@ -1 +1,2 @@
|
|||
from .cnn import CNNPredictor
|
||||
from .cnn import CNNPredictor
|
||||
from .transformer import Transformer
|
||||
1
src/models/transformer/__init__.py
Normal file
1
src/models/transformer/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .transformer import Transformer
|
||||
60
src/models/transformer/transformer.py
Normal file
60
src/models/transformer/transformer.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Transformer(nn.Transformer):
|
||||
def __init__(
|
||||
self,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
num_encoder_layers=6,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
layer_norm_eps=1e-05
|
||||
):
|
||||
super().__init__(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
num_decoder_layers=num_decoder_layers,
|
||||
dim_feedforward=dim_feedforward,
|
||||
dropout=dropout,
|
||||
activation=activation,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
batch_first=False,
|
||||
norm_first=False,
|
||||
device=None,
|
||||
dtype=None
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
tgt: Tensor,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
src_is_causal: Optional[bool] = None,
|
||||
tgt_is_causal: Optional[bool] = None,
|
||||
memory_is_causal: bool = False,
|
||||
) -> Tensor:
|
||||
return super().forward(
|
||||
src,
|
||||
tgt,
|
||||
src_mask,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
src_key_padding_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
src_is_causal,
|
||||
tgt_is_causal,
|
||||
memory_is_causal,
|
||||
)
|
||||
Reference in a new issue