import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import softmax


class CausalConv1d(nn.Conv1d):
    """1-D convolution whose output at step ``t`` only sees inputs ``<= t``.

    ``nn.Conv1d`` pads symmetrically, so padding by ``(kernel_size - 1) *
    dilation`` yields extra trailing frames that were computed from the
    right-hand zero padding. ``forward`` drops those frames, so with the
    default ``stride=1`` the output has the same length as the input.

    Args:
        input_channels: Number of channels of the input signal.
        output_channels: Number of channels produced by the convolution.
        kernel_size: Integer width of the convolution kernel.
        **kwargs: Forwarded to ``nn.Conv1d`` (e.g. ``dilation``, ``stride``,
            ``bias``). ``padding`` is set by this class and must not be given.
    """

    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
        # The left context required for causality grows with the dilation.
        dilation = kwargs.get("dilation", 1)
        if isinstance(dilation, (tuple, list)):
            dilation = dilation[0]
        causal_padding = (kernel_size - 1) * dilation
        super().__init__(
            input_channels,
            output_channels,
            kernel_size,
            padding=causal_padding,
            **kwargs,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the convolution and discard the non-causal trailing frames.

        Args:
            input: Tensor whose last dimension is the sequence axis.

        Returns:
            Tensor whose last dimension holds ``(L - 1) // stride + 1``
            frames, where ``L`` is the input length (``L`` for ``stride=1``).
        """
        output = super().forward(input)
        # Output frame t has its window ending at input index t * stride, so
        # only frames with t * stride <= L - 1 end on a real sample; the later
        # ones overlap the right-hand zero padding and are cut off here.
        num_causal = (input.size(-1) - 1) // self.stride[0] + 1
        return output[..., :num_causal]


class CNNPredictor(nn.Module):
    """Next-token predictor built from a stack of causal 1-D convolutions.

    Token ids are embedded, passed through ``num_layers`` blocks of
    ``CausalConv1d -> [BatchNorm1d] -> ReLU -> Dropout`` and the features of
    the final time step are projected onto the vocabulary.

    Args:
        vocab_size: Number of distinct token ids (also the output size).
        num_layers: Number of convolutional blocks.
        hidden_dim: Embedding size and channel count of every block.
        kernel_size: Kernel width of every causal convolution.
        dropout_prob: Dropout probability applied after each activation.
        use_batchnorm: Whether to insert ``BatchNorm1d`` after each convolution.
    """

    def __init__(
        self,
        vocab_size=256,
        num_layers=3,
        hidden_dim=128,
        kernel_size=3,
        dropout_prob=0.1,
        use_batchnorm=False,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)

        layers = []
        in_channels = hidden_dim
        for _ in range(num_layers):
            out_channels = hidden_dim
            layers.append(CausalConv1d(in_channels, out_channels, kernel_size))
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(out_channels))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_prob))
            in_channels = out_channels

        self.network = nn.Sequential(*layers)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the probability distribution over the next token.

        Args:
            x: Integer token ids of shape ``(B, L)`` with ``L >= 1``.

        Returns:
            Tensor of shape ``(B, vocab_size)`` whose rows sum to one.

        Raises:
            ValueError: If the sequence dimension of ``x`` is empty.
        """
        if x.size(-1) == 0:
            raise ValueError("CNNPredictor requires a non-empty input sequence")

        embedding = self.embedding(x)  # B, L, H
        embedding = embedding.transpose(1, 2)  # B, H, L
        prediction = self.network(embedding)  # B, H, L (length preserved)
        # The causal stack keeps the sequence length, so the last frame is the
        # one that has seen the whole receptive field ending at the last token.
        last_prediction = prediction[:, :, -1]
        # Convert the output of the linear layer to a probability distribution.
        # NOTE(review): these are probabilities, not logits; a loss that applies
        # its own softmax would need the pre-softmax values instead.
        return softmax(self.output_layer(last_prediction), dim=-1)