feat: new CNN, start of creating graphs

RobinMeersman 2025-12-14 18:36:40 +01:00
parent 17e0b52600
commit 5bb254d6c2
7 changed files with 151 additions and 49 deletions

View file

@@ -50,12 +50,16 @@ class AutoEncoder(Model):
         """
         x: torch.Tensor of floats
         """
+        if len(x.shape) == 2:
+            x = x.unsqueeze(1)
         return self.encoder(x)
 
     def decode(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: torch.Tensor of floats
         """
+        if len(x.shape) == 2:
+            x = x.unsqueeze(1)
         return self.decoder(x)
 
     def forward(self, x: torch.LongTensor) -> torch.Tensor:

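The added shape check lets encode and decode accept both (B, L) and (B, 1, L) inputs by inserting the channel dimension a Conv1d-based model expects. A minimal sketch of that behavior, using a hypothetical stand-in encoder since the AutoEncoder's actual layers are not part of this hunk:

import torch
import torch.nn as nn

# Hypothetical stand-in; the repo's real encoder layers are not shown in this diff.
encoder = nn.Sequential(nn.Conv1d(1, 8, kernel_size=3, padding=1), nn.ReLU())

x = torch.randn(4, 128)      # (B, L): raw input without a channel dimension
if len(x.shape) == 2:        # the check this commit adds
    x = x.unsqueeze(1)       # (B, 1, L)
print(encoder(x).shape)      # torch.Size([4, 8, 128])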
View file

@@ -1,57 +1,51 @@
 import torch
 import torch.nn as nn
 
 from src.models import Model
 
 
 class CNNPredictor(Model):
     def __init__(
-        self,
-        vocab_size=256,
-        embed_dim=64,
-        hidden_dim=128,
-    ):
-        super().__init__(nn.CrossEntropyLoss())
+        self,
+        vocab_size: int = 256,
+        hidden_dim: int = 128,
+    ):
+        super().__init__(loss_function=nn.CrossEntropyLoss())
 
-        # 1. Embedding: maps bytes (0–255) → vectors
-        self.embed = nn.Embedding(vocab_size, embed_dim)
-
-        # 2. Convolutional feature extractor
-        self.conv_layers = nn.Sequential(
-            nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
+        # Treat bytes as a 1D signal with 1 channel
+        self.feature_extractor = nn.Sequential(
+            nn.Conv1d(1, hidden_dim, kernel_size=3, padding=1),
             nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
-            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
-            nn.BatchNorm1d(hidden_dim),
+            nn.MaxPool1d(kernel_size=2),
+            nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=3, padding=1),
+            nn.BatchNorm1d(2 * hidden_dim),
             nn.ReLU(),
-            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
-            nn.BatchNorm1d(hidden_dim),
+            nn.MaxPool1d(kernel_size=2),
+            nn.Conv1d(2 * hidden_dim, 2 * hidden_dim, kernel_size=3, padding=1),
+            nn.BatchNorm1d(2 * hidden_dim),
             nn.ReLU(),
         )
 
-        # 3. Global pooling to collapse sequence length
-        self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_channels, 1)
+        # Collapse sequence dimension → fixed-size representation
+        self.global_pool = nn.AdaptiveAvgPool1d(1)  # (B, hidden_dim, 1)
 
-        # 4. Final classifier
-        self.fc = nn.Linear(hidden_dim, vocab_size)  # → (B, 256)
+        # Classification head
+        self.classifier = nn.Linear(2 * hidden_dim, vocab_size)
 
-    def forward(self, x):
+    def forward(self, x: torch.LongTensor) -> torch.Tensor:
         """
-        x: LongTensor of shape (B, 128), values 0-255
+        x: (B, L) LongTensor with values in [0, 255]
         returns: logits (B, 256)
         """
-        # embed: (B, 128, embed_dim)
-        x = self.embed(x)
-        # conv1d expects (B, C_in, L) → swap dims
-        x = x.transpose(1, 2)  # (B, embed_dim, 128)
-        # apply CNN
-        x = self.conv_layers(x)  # (B, hidden_channels, 128)
-        # global average pooling over sequence
-        x = self.pool(x).squeeze(-1)  # (B, hidden_channels)
-        # final classifier
-        logits = self.fc(x)  # (B, 256)
-        return logits
+        # Convert bytes to float signal
+        x = x.float() / 255.0  # (B, L)
+        x = x.unsqueeze(1)  # (B, 1, L)
+        features = self.feature_extractor(x)  # (B, 2 * hidden_dim, L')
+        pooled = self.global_pool(features)  # (B, 2 * hidden_dim, 1)
+        pooled = pooled.squeeze(-1)  # (B, 2 * hidden_dim)
+        logits = self.classifier(pooled)  # (B, 256)
+        return logits
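
As a smoke test, the new predictor can be exercised end to end; a sketch assuming CNNPredictor is imported from its module under src (the import path, batch size, and sequence length here are illustrative, not from the repo):

import torch
import torch.nn as nn

model = CNNPredictor()                 # defaults: vocab_size=256, hidden_dim=128
x = torch.randint(0, 256, (8, 128))    # 8 sequences of 128 raw bytes
logits = model(x)                      # (8, 256): one score per possible next byte

targets = torch.randint(0, 256, (8,))  # hypothetical next-byte labels
loss = nn.CrossEntropyLoss()(logits, targets)
print(logits.shape, loss.item())

One design note: replacing the learned nn.Embedding with x.float() / 255.0 treats bytes as ordinal intensities on a 1-channel signal rather than categorical symbols, so nearby byte values are now implicitly similar to the first conv layer.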