fix: conflicts

RobinMeersman 2025-12-13 11:27:46 +01:00
commit b178c097d8
90 changed files with 2034 additions and 11145 deletions

14
src/models/Model.py Normal file

@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod

from torch import nn


class Model(nn.Module, ABC):
    @abstractmethod
    def __init__(self, loss_function=None):
        super().__init__()
        self._loss_function = loss_function

    @property
    def loss_function(self):
        return self._loss_function
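
For context, a minimal sketch of how a subclass's attached criterion might be consumed in a training step; the train_step helper, optimizer, and tensors below are illustrative assumptions, not part of this commit:

import torch

from src.models import Model


def train_step(model: Model, optimizer: torch.optim.Optimizer,
               x: torch.Tensor, target: torch.Tensor) -> float:
    # Every concrete Model is expected to pass its criterion to
    # Model.__init__, so training code can stay model-agnostic.
    optimizer.zero_grad()
    loss = model.loss_function(model(x), target)
    loss.backward()
    optimizer.step()
    return loss.item()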

8
src/models/__init__.py Normal file

@@ -0,0 +1,8 @@
from .Model import Model
from .cnn import CNNPredictor
from .transformer import ByteTransformer

model_called: dict[str, type[Model]] = {
    'cnn': CNNPredictor,
    'transformer': ByteTransformer
}
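
A sketch of how this name-to-class registry might be used to build a model from a config string; the build_model helper and its keyword arguments are hypothetical, not part of this commit:

from src.models import Model, model_called


def build_model(name: str, **kwargs) -> Model:
    # Look up the class registered under `name` and instantiate it.
    if name not in model_called:
        raise ValueError(f"unknown model '{name}', expected one of {list(model_called)}")
    return model_called[name](**kwargs)


model = build_model('cnn', vocab_size=256, embed_dim=64, hidden_dim=128)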

18
src/models/autoencoder.py Normal file

@@ -0,0 +1,18 @@
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Encoder, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass  # stub: encoder forward pass not implemented in this commit

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass  # stub: decoder forward pass not implemented in this commit
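
For orientation only, one way the Encoder stub could eventually be filled in; the MLP structure and layer choices are assumptions, not something this commit decides:

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Plain MLP mapping input_size → hidden_size → output_size (latent dim).
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)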

1
src/models/cnn/__init__.py Normal file

@@ -0,0 +1 @@
from .cnn import CNNPredictor

54
src/models/cnn/cnn.py Normal file

@@ -0,0 +1,54 @@
import torch.nn as nn

from src.models import Model


class CNNPredictor(Model):
    def __init__(
        self,
        vocab_size=256,
        embed_dim=64,
        hidden_dim=128,
    ):
        super().__init__(nn.CrossEntropyLoss())

        # 1. Embedding: maps bytes (0–255) → vectors
        self.embed = nn.Embedding(vocab_size, embed_dim)

        # 2. Convolutional feature extractor
        self.conv_layers = nn.Sequential(
            nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # 3. Global pooling to collapse sequence length
        self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_dim, 1)

        # 4. Final classifier
        self.fc = nn.Linear(hidden_dim, vocab_size)  # → (B, 256)

    def forward(self, x):
        """
        x: LongTensor of shape (B, 128), values 0–255
        """
        # embed: (B, 128, embed_dim)
        x = self.embed(x)

        # conv1d expects (B, C_in, L) → swap dims
        x = x.transpose(1, 2)  # (B, embed_dim, 128)

        # apply CNN
        x = self.conv_layers(x)  # (B, hidden_dim, 128)

        # global average pooling over sequence
        x = self.pool(x).squeeze(-1)  # (B, hidden_dim)

        # final classifier
        logits = self.fc(x)  # (B, 256)
        return logits
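
A quick shape check of the byte-level CNN; the batch size and tensor names here are illustrative, and assume the package is importable as src.models:

import torch

from src.models import CNNPredictor


model = CNNPredictor()                      # vocab_size=256, embed_dim=64, hidden_dim=128
x = torch.randint(0, 256, (32, 128))        # 32 sequences of 128 bytes
target = torch.randint(0, 256, (32,))       # next byte per sequence
logits = model(x)                           # (32, 256)
loss = model.loss_function(logits, target)  # CrossEntropyLoss from the Model base class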

1
src/models/transformer/__init__.py Normal file

@@ -0,0 +1 @@
from .transformer import ByteTransformer

70
src/models/transformer/transformer.py Normal file

@@ -0,0 +1,70 @@
import torch.nn as nn
from torch import Tensor, arange

from src.models import Model


class LearnedPositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # x: [seq, batch, d_model]
        seq_len = x.size(0)
        positions = arange(seq_len, device=x.device).unsqueeze(1)  # [seq, 1]
        return x + self.pos_emb(positions)  # broadcast over batch


class ByteTransformer(Model):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-05,
        max_len=128
    ):
        # Register the criterion with the Model base class so it is exposed
        # through the read-only loss_function property.
        super().__init__(nn.CrossEntropyLoss())

        self.src_embedding = nn.Embedding(256, d_model)
        self.tgt_embedding = nn.Embedding(256, d_model)
        self.src_pos = LearnedPositionalEncoding(max_len, d_model)
        self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=False,
            norm_first=False,
            device=None,
            dtype=None,
        )

        self.output_proj = nn.Linear(d_model, 256)

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
    ) -> Tensor:
        # Inputs are byte indices of shape [seq, batch] (batch_first=False).
        src_embeds = self.src_embedding(src)
        tgt_embeds = self.tgt_embedding(tgt)
        src_pos = self.src_pos(src_embeds)
        tgt_pos = self.tgt_pos(tgt_embeds)
        return self.output_proj(self.transformer(src_pos, tgt_pos))
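
A usage sketch for the seq-first byte transformer; the batch size, the shift convention for the decoder input, and the masking note are assumptions about how this model would be trained, not something this commit specifies:

import torch
import torch.nn as nn

from src.models import ByteTransformer


model = ByteTransformer(max_len=128)

# Byte indices of shape [seq, batch] because the transformer uses batch_first=False.
src = torch.randint(0, 256, (128, 32))
tgt_in = torch.randint(0, 256, (128, 32))   # decoder input (e.g. target shifted right)
tgt_out = torch.randint(0, 256, (128, 32))  # bytes the decoder should predict

logits = model(src, tgt_in)                 # [seq, batch, 256]
loss = model.loss_function(logits.reshape(-1, 256), tgt_out.reshape(-1))

# Note: forward() does not apply a causal mask to the decoder input; for
# autoregressive training one would typically thread a mask such as
# nn.Transformer.generate_square_subsequent_mask(tgt_in.size(0)) through to
# self.transformer(..., tgt_mask=...).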