45 lines
1.6 KiB
Python
45 lines
1.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.nn.functional import softmax
|
|
|
|
|
|
class CausalConv1d(nn.Conv1d):
    """1-D convolution in which output step ``t`` sees only input steps ``<= t``.

    ``nn.Conv1d`` applies ``padding`` to *both* ends of the sequence. This class
    pads by ``(kernel_size - 1) * dilation`` and then drops the same number of
    trailing outputs, which are the ones produced by the right-hand padding.
    What remains is causal and has the same length as the input.

    Args:
        input_channels: Number of input channels.
        output_channels: Number of output channels.
        kernel_size: Kernel width (an int).
        **kwargs: Forwarded to ``nn.Conv1d`` (e.g. ``dilation``, ``bias``).
            ``padding`` is set by this class and must not be passed.
    """

    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
        dilation = kwargs.get("dilation", 1)
        if isinstance(dilation, (tuple, list)):
            # nn.Conv1d also accepts a 1-tuple for dilation.
            dilation = dilation[0]
        # Enough left context for every output step; scaled by dilation so
        # dilated kernels remain causal as well.
        causal_padding = (kernel_size - 1) * dilation
        super().__init__(
            input_channels,
            output_channels,
            kernel_size,
            padding=causal_padding,
            **kwargs,
        )
        # Number of trailing output steps to discard in forward().
        self._causal_trim = causal_padding

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` (..., C_in, L) into a causal (..., C_out, L) output."""
        output = super().forward(input)
        if self._causal_trim:
            # These last steps only exist because of the right-hand padding;
            # keeping them shifts every position and shrinks the receptive
            # field of the final step to the last input alone.
            output = output[..., :-self._causal_trim]
        return output
|
|
|
|
class CNNPredictor(nn.Module):
    """Next-token predictor built from a stack of causal 1-D convolutions.

    Tokens are embedded into ``hidden_dim`` channels and passed through
    ``num_layers`` convolution stages (CausalConv1d -> optional BatchNorm1d ->
    ReLU -> Dropout). The features at the final sequence position are then
    projected to a distribution over ``vocab_size`` tokens.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output classes).
        num_layers: Number of convolution stages.
        hidden_dim: Embedding size and channel count of every convolution stage.
        kernel_size: Kernel width of each CausalConv1d.
        dropout_prob: Dropout probability applied after each ReLU.
        use_batchnorm: If True, insert BatchNorm1d after each convolution.
    """

    def __init__(
        self,
        vocab_size=256,
        num_layers=3,
        hidden_dim=128,
        kernel_size=3,
        dropout_prob=0.1,
        use_batchnorm=False
    ):
        super().__init__()
        # The registration order (embedding, network, output_layer) and the
        # module order inside ``network`` fix the state_dict keys and the order
        # in which parameters are initialised, so keep them stable.
        self.embedding = nn.Embedding(vocab_size, hidden_dim)

        stages = []
        for _ in range(num_layers):
            stages.append(CausalConv1d(hidden_dim, hidden_dim, kernel_size))
            if use_batchnorm:
                stages.append(nn.BatchNorm1d(hidden_dim))
            stages.extend((nn.ReLU(), nn.Dropout(dropout_prob)))
        self.network = nn.Sequential(*stages)

        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return next-token probabilities for each sequence in the batch.

        Args:
            x: Integer token ids of shape (B, L).

        Returns:
            Tensor of shape (B, vocab_size) whose last dimension sums to 1
            (softmax is already applied).
        """
        # nn.Conv1d expects channels before length: (B, L, H) -> (B, H, L).
        channels_first = self.embedding(x).transpose(1, 2)
        features = self.network(channels_first)
        # Only the features at the final time step feed the prediction.
        logits = self.output_layer(features[:, :, -1])
        # NOTE(review): this returns probabilities, not logits. If training uses
        # nn.CrossEntropyLoss (which applies log-softmax itself), softmax would
        # effectively be applied twice; confirm against the training code.
        return softmax(logits, dim=-1)
|
|
|