import torch
import torch.nn as nn


class CNNPredictor(nn.Module):
    """1-D convolutional classifier over a sequence of integer ids (e.g. bytes).

    Pipeline: embedding -> stack of length-preserving ``Conv1d`` + ``ReLU``
    layers -> global average pooling over the sequence -> linear layer that
    produces one logit per vocabulary entry.

    Args:
        vocab_size: Number of distinct input ids; also the number of output
            logits. The default of 256 covers raw byte values 0-255.
        embed_dim: Size of each embedding vector.
        hidden_channels: Output channel count of every convolutional layer.
        kernel_size: Width of each convolution kernel. Must be a positive odd
            integer so that ``padding = kernel_size // 2`` keeps the sequence
            length unchanged.
        num_layers: Number of ``Conv1d`` + ``ReLU`` pairs (at least 1).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer or
            ``num_layers`` is less than 1.
    """

    def __init__(
        self,
        vocab_size=256,
        embed_dim=64,
        hidden_channels=128,
        kernel_size=5,
        num_layers=3,
    ):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # 1. Embedding: maps ids in [0, vocab_size) to embed_dim-sized vectors.
        self.embed = nn.Embedding(vocab_size, embed_dim)

        # 2. Convolutional feature extractor. Modules are appended as
        # Conv1d, ReLU, Conv1d, ReLU, ... so the default configuration has the
        # same Sequential indices (and state_dict keys) as a hand-written stack.
        padding = kernel_size // 2  # keeps length L unchanged for odd kernels
        layers = []
        in_channels = embed_dim
        for _ in range(num_layers):
            layers.append(
                nn.Conv1d(
                    in_channels,
                    hidden_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                )
            )
            layers.append(nn.ReLU())
            in_channels = hidden_channels
        self.conv_layers = nn.Sequential(*layers)

        # 3. Global average pooling collapses the sequence dimension:
        #    (B, hidden_channels, L) -> (B, hidden_channels, 1).
        # NOTE(review): averaging over every position discards ordering
        # information beyond the convolutions' receptive field; if the goal is
        # next-token prediction, a last-position readout may be preferable.
        self.pool = nn.AdaptiveAvgPool1d(1)

        # 4. Final classifier: (B, hidden_channels) -> (B, vocab_size).
        self.fc = nn.Linear(hidden_channels, vocab_size)

    def forward(self, x):
        """Return class logits for a batch of id sequences.

        Args:
            x: Integer (Long) tensor of shape (B, L) with values in
                [0, vocab_size). Any length L >= 1 is accepted.

        Returns:
            Float tensor of shape (B, vocab_size) with unnormalized logits.
        """
        x = self.embed(x)  # (B, L, embed_dim)

        # Conv1d expects channels-first input (B, C_in, L).
        x = x.transpose(1, 2)  # (B, embed_dim, L)

        x = self.conv_layers(x)  # (B, hidden_channels, L)

        # Pool over the sequence and drop only the trailing singleton dim,
        # so a batch of size 1 keeps its batch dimension.
        x = self.pool(x).squeeze(-1)  # (B, hidden_channels)

        return self.fc(x)  # (B, vocab_size)