code cleanup
This commit is contained in:
parent
ea9cf12db0
commit
73d1742cbd
44 changed files with 6 additions and 2835 deletions
1
models/cnn/__init__.py
Normal file
1
models/cnn/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .cnn import CNNPredictor
|
||||
52
models/cnn/cnn.py
Normal file
52
models/cnn/cnn.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class CNNPredictor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256,
|
||||
embed_dim=64,
|
||||
hidden_dim=128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Embedding: maps bytes (0–255) → vectors
|
||||
self.embed = nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# 2. Convolutional feature extractor
|
||||
self.conv_layers = nn.Sequential(
|
||||
nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# 3. Global pooling to collapse sequence length
|
||||
self.pool = nn.AdaptiveAvgPool1d(1) # → (B, hidden_channels, 1)
|
||||
|
||||
# 4. Final classifier
|
||||
self.fc = nn.Linear(hidden_dim, vocab_size) # → (B, 256)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: LongTensor of shape (B, 128), values 0-255
|
||||
"""
|
||||
# embed: (B, 128, embed_dim)
|
||||
x = self.embed(x)
|
||||
|
||||
# conv1d expects (B, C_in, L) → swap dims
|
||||
x = x.transpose(1, 2) # (B, embed_dim, 128)
|
||||
|
||||
# apply CNN
|
||||
x = self.conv_layers(x) # (B, hidden_channels, 128)
|
||||
|
||||
# global average pooling over sequence
|
||||
x = self.pool(x).squeeze(-1) # (B, hidden_channels)
|
||||
|
||||
# final classifier
|
||||
logits = self.fc(x) # (B, 256)
|
||||
return logits
|
||||
|
||||
|
||||
Reference in a new issue