import torch
import torch.nn as nn


class CNNPredictor(nn.Module):
    def __init__(
        self,
        vocab_size=256,
        embed_dim=64,
        hidden_dim=128,
    ):
        super().__init__()

        # 1. Embedding: maps bytes (0–255) → vectors
        self.embed = nn.Embedding(vocab_size, embed_dim)

        # 2. Convolutional feature extractor
        self.conv_layers = nn.Sequential(
            nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # 3. Global pooling to collapse sequence length
        self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_dim, 1)

        # 4. Final classifier
        self.fc = nn.Linear(hidden_dim, vocab_size)  # → (B, 256)

    def forward(self, x):
        """
        x: LongTensor of shape (B, 128), values 0-255
        """
        # embed: (B, 128, embed_dim)
        x = self.embed(x)

        # Conv1d expects (B, C_in, L) → swap dims
        x = x.transpose(1, 2)  # (B, embed_dim, 128)

        # apply CNN
        x = self.conv_layers(x)  # (B, hidden_dim, 128)

        # global average pooling over sequence
        x = self.pool(x).squeeze(-1)  # (B, hidden_dim)

        # final classifier
        logits = self.fc(x)  # (B, 256)
        return logits
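

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original file):
# treats CNNPredictor as a next-byte classifier over 128-byte input windows,
# which is what the shapes above suggest.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = CNNPredictor()
    batch = torch.randint(0, 256, (4, 128))  # 4 random stand-in byte sequences
    logits = model(batch)                    # (4, 256)
    print(logits.shape)                      # torch.Size([4, 256])

    # A typical training step would use cross-entropy over the 256 byte
    # classes; these targets are random placeholders, not real labels.
    targets = torch.randint(0, 256, (4,))
    loss = nn.functional.cross_entropy(logits, targets)
    print(f"dummy loss: {loss.item():.4f}")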