fix: fixed model shapes + redit training loop
This commit is contained in:
parent
ed44d5b283
commit
eb4a014aa1
3 changed files with 68 additions and 56 deletions
|
|
@ -1,45 +1,52 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn.functional import softmax
|
||||
|
||||
|
||||
class CausalConv1d(nn.Conv1d):
|
||||
def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
|
||||
super().__init__(input_channels, output_channels, kernel_size, padding=kernel_size-1, **kwargs)
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return super().forward(input)[:, :, :input.size(-1)]
|
||||
|
||||
class CNNPredictor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256,
|
||||
num_layers=3,
|
||||
hidden_dim=128,
|
||||
kernel_size=3,
|
||||
dropout_prob=0.1,
|
||||
use_batchnorm=False
|
||||
):
|
||||
self,
|
||||
vocab_size=256,
|
||||
embed_dim=64,
|
||||
hidden_channels=128,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(vocab_size, hidden_dim)
|
||||
layers = []
|
||||
in_channels = hidden_dim
|
||||
for _ in range(num_layers):
|
||||
out_channels = hidden_dim
|
||||
layers.append(CausalConv1d(in_channels, out_channels, kernel_size))
|
||||
if use_batchnorm:
|
||||
layers.append(nn.BatchNorm1d(out_channels))
|
||||
layers.append(nn.ReLU())
|
||||
layers.append(nn.Dropout(dropout_prob))
|
||||
in_channels = out_channels
|
||||
|
||||
self.network = nn.Sequential(*layers)
|
||||
self.output_layer = nn.Linear(hidden_dim, vocab_size)
|
||||
# 1. Embedding: maps bytes (0–255) → vectors
|
||||
self.embed = nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# 2. Convolutional feature extractor
|
||||
self.conv_layers = nn.Sequential(
|
||||
nn.Conv1d(embed_dim, hidden_channels, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# 3. Global pooling to collapse sequence length
|
||||
self.pool = nn.AdaptiveAvgPool1d(1) # → (B, hidden_channels, 1)
|
||||
|
||||
# 4. Final classifier
|
||||
self.fc = nn.Linear(hidden_channels, vocab_size) # → (B, 256)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: LongTensor of shape (B, 128), values 0-255
|
||||
"""
|
||||
# embed: (B, 128, embed_dim)
|
||||
x = self.embed(x)
|
||||
|
||||
# conv1d expects (B, C_in, L) → swap dims
|
||||
x = x.transpose(1, 2) # (B, embed_dim, 128)
|
||||
|
||||
# apply CNN
|
||||
x = self.conv_layers(x) # (B, hidden_channels, 128)
|
||||
|
||||
# global average pooling over sequence
|
||||
x = self.pool(x).squeeze(-1) # (B, hidden_channels)
|
||||
|
||||
# final classifier
|
||||
logits = self.fc(x) # (B, 256)
|
||||
return logits
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
emdedding = self.embedding(x) # B, L, H
|
||||
emdedding = emdedding.transpose(1, 2) # B, H, L
|
||||
prediction = self.network(emdedding)
|
||||
last_prediction = prediction[:, :, -1]
|
||||
return self.output_layer(last_prediction)
|
||||
|
||||
|
|
|
|||
Reference in a new issue