from typing import Optional import torch.nn as nn from torch import Tensor class Transformer(nn.Transformer): def __init__( self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation="relu", layer_norm_eps=1e-05 ): super().__init__( d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, batch_first=False, norm_first=False, device=None, dtype=None ) def forward( self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None, memory_is_causal: bool = False, ) -> Tensor: return super().forward( src, tgt, src_mask, tgt_mask, memory_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask, src_is_causal, tgt_is_causal, memory_is_causal, )