from typing import Optional

import torch.nn as nn
from torch import Tensor


class Transformer(nn.Transformer):
    """Thin wrapper around ``nn.Transformer`` that exposes the most common
    constructor arguments and pins the remaining ones (``batch_first=False``,
    ``norm_first=False``) to their defaults. ``forward`` delegates directly
    to the parent class."""

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-05,
    ):
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=False,
            norm_first=False,
            device=None,
            dtype=None,
        )

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Optional[Tensor] = None,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        src_is_causal: Optional[bool] = None,
        tgt_is_causal: Optional[bool] = None,
        memory_is_causal: bool = False,
    ) -> Tensor:
        # Delegate to nn.Transformer.forward, passing the masks and causal
        # flags positionally in the parent's argument order.
        return super().forward(
            src,
            tgt,
            src_mask,
            tgt_mask,
            memory_mask,
            src_key_padding_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            src_is_causal,
            tgt_is_causal,
            memory_is_causal,
        )
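

# Minimal usage sketch: exercises the wrapper end to end with random
# tensors. The shapes assume the default batch_first=False layout,
# i.e. (seq_len, batch, d_model); the sizes below are illustrative.
if __name__ == "__main__":
    import torch

    model = Transformer(d_model=512, nhead=8)
    src = torch.rand(10, 32, 512)  # (source length, batch, d_model)
    tgt = torch.rand(20, 32, 512)  # (target length, batch, d_model)

    # Standard causal mask for the decoder, built with the static helper
    # that ships with nn.Transformer.
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0))

    out = model(src, tgt, tgt_mask=tgt_mask)
    print(out.shape)  # torch.Size([20, 32, 512])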