chore: Restructure

Tibo De Peuter 2025-12-05 12:37:48 +01:00
parent 8b6c4e17ab
commit f32f4678e1
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
62 changed files with 0 additions and 10547 deletions

@@ -0,0 +1,115 @@
from abc import abstractmethod, ABC
from os.path import join, curdir
from typing import Callable
import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm
"""
Author: Tibo De Peuter
"""
class Dataset(TorchDataset, ABC):
"""Abstract base class for datasets."""
@abstractmethod
def __init__(self,
name: str,
root: str | None,
split: str = 'train',
transform: Callable = None,
size: int = -1
):
"""
:param root: Path to the dataset root directory
:param split: The dataset split, e.g. 'train', 'validation', 'test'
:param size: Override the maximum size of the dataset, useful for debugging
"""
if root is None:
root = join(curdir, 'data')
self._root = join(root, name)
self.split = split
self.transform = transform
self.size = size
self.data = None
self.chunk_offsets: list[int] = []
self.bytes: bytes = bytes()
self.tensor: Tensor = torch.tensor([])
@property
def root(self):
return self._root
def __len__(self):
        return len(self.data)
    def process_data(self):
        if self.size == -1:
            # Just use the whole dataset
            self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
            # Keep the offsets usable even without a size limit; subclasses read chunk_offsets[-1]
            self.chunk_offsets = [0, len(self.bytes)]
        else:
            # Use only a partition, calculate offsets
            self.chunk_offsets = self.get_offsets()
            self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
        self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
def get_offsets(self):
"""
Calculate for each chunk how many bytes came before it
"""
offsets = [0]
while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
idx = len(offsets) - 1
offsets.append(offsets[idx] + len(self.data[idx]))
print(offsets)
return offsets
def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
item = ''
# Determine first chunk in which item is located
chunk_idx = 0
while idx >= offsets[chunk_idx]:
chunk_idx += 1
chunk_idx -= 1
# Extract item from chunks
chunk = str(self.data[chunk_idx])
chunk_start = offsets[chunk_idx]
chunk_item_start = idx - chunk_start
item_len_remaining = context_length + 1
assert len(item) + item_len_remaining == context_length + 1
while chunk_item_start + item_len_remaining > len(chunk):
adding_now_len = len(chunk) - chunk_item_start
item += chunk[chunk_item_start:]
chunk_idx += 1
chunk = str(self.data[chunk_idx])
chunk_item_start = 0
item_len_remaining -= adding_now_len
assert len(item) + item_len_remaining == context_length + 1
item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]
assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"
# Transform to tensor
data = ''.join(item).encode('utf-8', errors='replace')
t = torch.tensor(list(data), dtype=torch.long)
x, y = t[:-1], t[-1]
if self.transform:
x = self.transform(x)
return x, y
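
As a usage sketch (not part of this commit), a hypothetical minimal subclass showing the pattern the base class expects: fill self.data with text chunks, set a context length, call process_data(), then index self.tensor:

class ToyDataset(Dataset):
    """Hypothetical example subclass, for illustration only."""
    def __init__(self, root=None, split='train', transform=None, size=-1):
        super().__init__('toy', root, split, transform, size)
        self.data = ["lorem ipsum ", "dolor sit amet "]  # list of text chunks
        self.context_length = 8
        self.process_data()

    def __len__(self):
        return len(self.tensor) - self.context_length

    def __getitem__(self, idx):
        x = self.tensor[idx:idx + self.context_length]
        y = self.tensor[idx + self.context_length]
        if self.transform:
            x = self.transform(x)
        return x, y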

@@ -0,0 +1,59 @@
from math import ceil
from typing import Callable
from datasets import load_dataset, Features, Value
from .Dataset import Dataset
class EnWik9DataSet(Dataset):
"""
Hugging Face: https://huggingface.co/datasets/haukur/enwik9
"""
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable | None = None,
size: int = -1
):
super().__init__('enwik9', root, split, transform, size)
print(f"Loading from HuggingFace")
ft = Features({'text': Value('string')})
# Don't pass split here, dataset only contains training
text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
self.data = text_chunks['text']
self.size = size
# Model uses fixed 128-length context
self.context_length = 128
self.process_data()
# Define splits manually, because they do not exist in the dataset
split_point = ceil(self.chunk_offsets[-1] * 0.8)
if self.split == 'train':
self.start_byte = 0
self.end_byte = split_point
elif self.split == 'validation':
self.start_byte = split_point
self.end_byte = self.chunk_offsets[-1]
else:
raise ValueError("split must be 'train' or 'validation'")
print("Done initializing dataset")
def __len__(self):
return self.end_byte - self.start_byte - self.context_length
def __getitem__(self, idx):
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
y = self.tensor[self.start_byte + idx + self.context_length]
if self.transform:
x = self.transform(x)
return x, y
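
For example, assuming a standard torch DataLoader (the batch size is arbitrary), the dataset yields byte-level context/target pairs; note that constructing it downloads enwik9 from Hugging Face:

from torch.utils.data import DataLoader

train_set = EnWik9DataSet(split='train')
val_set = EnWik9DataSet(split='validation')
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
x, y = next(iter(train_loader))  # x: (64, 128) byte ids, y: (64,) next-byte targets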

@@ -0,0 +1,63 @@
from math import ceil
from typing import Callable
from lorem.text import TextLorem
from tqdm import tqdm
from .Dataset import Dataset
class LoremIpsumDataset(Dataset):
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable = None,
size: int = 2**30
):
super().__init__('lorem_ipsum', root, split, transform, size)
_lorem = TextLorem()
self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
self.size = size
self.context_length = 128
self.process_data()
split_point = ceil(self.chunk_offsets[-1] * 0.8)
if self.split == 'train':
self.start_byte = 0
self.end_byte = split_point
elif self.split == 'validation':
self.start_byte = split_point
self.end_byte = self.chunk_offsets[-1]
else:
raise ValueError("split must be 'train' or 'validation'")
print("Done initializing dataset")
def __len__(self):
return self.end_byte - self.start_byte - self.context_length
def __getitem__(self, idx):
# Get sequence of characters
# x_str = self.text[idx: idx + self.context_length]
# y_char = self.text[idx + self.context_length]
#
# # Convert to tensors
# x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
# y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
#
# if self.transform is not None:
# x = self.transform(x)
#
# return x, y
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
y = self.tensor[self.start_byte + idx + self.context_length]
if self.transform:
x = self.transform(x)
return x, y
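
A quick smoke-test sketch, assuming a much smaller size than the 2**30 default so generation stays fast:

ds = LoremIpsumDataset(split='train', size=10_000)  # only 10k generated words
x, y = ds[0]
print(x.shape, y.item())  # torch.Size([128]) and a single byte value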

@@ -0,0 +1,51 @@
from typing import Callable
from datasets import load_dataset, Value, Features
from .Dataset import Dataset
class OpenGenomeDataset(Dataset):
"""
Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome
    :param split: Either 'train', 'test' or 'validation'
    :param stage: Either 'sample', 'stage1' or 'stage2'.
        'sample' only provides a 'validation' split
"""
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable = None,
size: int = -1,
stage: str = 'stage2'
):
super().__init__('open_genome', root, split, transform, size)
print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
ft = Features({'text': Value('string')})
data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
self.data = data['text']
self.size = size
# Model uses fixed 128-length context
self.context_length = 128
self.process_data()
print("Done initializing dataset")
def __len__(self):
# return len(self.data) - self.context_length
return self.chunk_offsets[-1] - self.context_length
def __getitem__(self, idx):
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
x = self.tensor[idx:idx + self.context_length]
y = self.tensor[idx + self.context_length]
if self.transform:
x = self.transform(x)
return x, y
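
For a lightweight check, one could use the 'sample' configuration, which per the docstring above only ships a 'validation' split (this still downloads data from Hugging Face):

ds = OpenGenomeDataset(split='validation', stage='sample', size=100_000)
x, y = ds[0]  # 128 context bytes and the next byte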

@@ -0,0 +1,10 @@
from .Dataset import Dataset
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset
from .OpenGenomeDataset import OpenGenomeDataset
dataset_called: dict[str, type[Dataset]] = {
'enwik9': EnWik9DataSet,
'lorem_ipsum': LoremIpsumDataset,
'opengenome': OpenGenomeDataset
}
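
The dataset_called mapping acts as a small registry, so a dataset class can be picked by name, e.g. from a command-line argument (sketch; the argument handling is assumed):

name = 'enwik9'  # e.g. parsed from the CLI
dataset_cls = dataset_called[name]
train_set = dataset_cls(split='train')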

src/models/__init__.py

@@ -0,0 +1 @@
from .cnn import CNNPredictor

@@ -0,0 +1 @@
from .cnn import CNNPredictor

src/models/cnn/cnn.py

@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
class CNNPredictor(nn.Module):
def __init__(
self,
vocab_size=256,
embed_dim=64,
hidden_dim=128,
):
super().__init__()
        # 1. Embedding: maps bytes (0-255) → vectors
self.embed = nn.Embedding(vocab_size, embed_dim)
# 2. Convolutional feature extractor
self.conv_layers = nn.Sequential(
nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
nn.ReLU(),
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
nn.ReLU(),
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
nn.ReLU(),
)
# 3. Global pooling to collapse sequence length
self.pool = nn.AdaptiveAvgPool1d(1) # → (B, hidden_channels, 1)
# 4. Final classifier
self.fc = nn.Linear(hidden_dim, vocab_size) # → (B, 256)
def forward(self, x):
"""
x: LongTensor of shape (B, 128), values 0-255
"""
# embed: (B, 128, embed_dim)
x = self.embed(x)
# conv1d expects (B, C_in, L) → swap dims
x = x.transpose(1, 2) # (B, embed_dim, 128)
# apply CNN
x = self.conv_layers(x) # (B, hidden_channels, 128)
# global average pooling over sequence
x = self.pool(x).squeeze(-1) # (B, hidden_channels)
# final classifier
logits = self.fc(x) # (B, 256)
return logits
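
A minimal shape check (illustrative only): the model maps a batch of 128-byte contexts to logits over the 256 possible next bytes:

model = CNNPredictor()
x = torch.randint(0, 256, (4, 128))    # batch of 4 contexts of 128 byte ids
logits = model(x)                      # (4, 256)
probs = torch.softmax(logits, dim=-1)  # next-byte probabilities per sample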

@@ -0,0 +1,26 @@
from typing import Callable
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from .train import train
from ..utils import print_losses
class FullTrainer(Trainer):
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
if model is None:
raise ValueError("Model must be provided: run optuna optimizations first")
model.to(device)
        train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs, device=device)
print_losses(train_loss, val_loss)

@@ -0,0 +1,62 @@
from typing import Callable
import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from ..models.cnn import CNNPredictor
from .train import train
def create_model(trial: tr.Trial, vocab_size: int = 256):
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
return CNNPredictor(
vocab_size=vocab_size,
hidden_dim=hidden_dim,
embed_dim=embedding_dim,
)
def objective_function(
trial: tr.Trial,
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
device: str
):
model = create_model(trial).to(device)
    _, validation_loss = train(model, training_loader, validation_loader, loss_fn, device=device)
return min(validation_loss)
class OptunaTrainer(Trainer):
def __init__(self, n_trials: int | None = None):
super().__init__()
self.n_trials = n_trials if n_trials is not None else 20
print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
study = optuna.create_study(study_name="CNN network", direction="minimize")
study.optimize(
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
n_trials=self.n_trials
)
best_params = study.best_trial.params
best_model = CNNPredictor(
**best_params
)
        # Name the checkpoint after the freshly built best model (the incoming `model` may be None here)
        torch.save(best_model, f"saved_models/{best_model.__class__.__name__}.pt")

src/trainers/__init__.py

@@ -0,0 +1,3 @@
from .OptunaTrainer import OptunaTrainer
from .FullTrainer import FullTrainer
from .trainer import Trainer

src/trainers/train.py

@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from typing import Callable
def train(
model: nn.Module,
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
epochs: int = 100,
learning_rate: float = 1e-3,
weight_decay: float = 1e-8,
device="cuda"
) -> tuple[list[float], list[float]]:
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
avg_training_losses = []
avg_validation_losses = []
for epoch in range(epochs):
model.train()
total_loss = []
for x, y in tqdm(training_loader):
x = x.long().to(device) # important for Embedding
y = y.long().to(device) # must be (B,) for CE
optimizer.zero_grad()
logits = model(x) # (B, 256)
loss = loss_fn(logits, y)
loss.backward()
optimizer.step()
total_loss.append(loss.item())
avg_training_losses.append(sum(total_loss) / len(total_loss))
# ----- validation -----
model.eval()
with torch.no_grad():
losses = []
for x, y in validation_loader:
x = x.long().to(device)
y = y.long().to(device)
logits = model(x)
loss = loss_fn(logits, y)
losses.append(loss.item())
avg_loss = sum(losses) / len(losses)
avg_validation_losses.append(avg_loss)
return avg_training_losses, avg_validation_losses
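
Typical use is a sketch like the following, assuming nn.CrossEntropyLoss as loss_fn since the model returns raw logits over 256 byte classes, and print_losses from src/utils for plotting:

loss_fn = nn.CrossEntropyLoss()
train_losses, val_losses = train(model, train_loader, val_loader, loss_fn, epochs=10, device='cuda')
print_losses(train_losses, val_losses)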

src/trainers/trainer.py

@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Callable
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class Trainer(ABC):
"""Abstract class for trainers."""
@abstractmethod
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
pass
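
Both concrete trainers implement this execute() interface, so switching between them is a one-liner (sketch; the tune flag is hypothetical):

trainer = OptunaTrainer(n_trials=10) if tune else FullTrainer()
trainer.execute(model, train_loader, val_loader, nn.CrossEntropyLoss(), n_epochs=10, device='cuda')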

src/utils/__init__.py

@@ -0,0 +1 @@
from .utils import *

src/utils/utils.py

@@ -0,0 +1,31 @@
import torch
from torch.utils.data import TensorDataset
import matplotlib.pyplot as plt
def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
data = torch.tensor(list(data), dtype=torch.long)
sample_count = data.shape[0] - context_length
x = data.unfold(0, context_length, 1)[:sample_count]
y = data[context_length:]
return TensorDataset(x, y)
def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
plt.show()
def print_losses(train_losses: list[float], validation_losses: list[float], show=False):
plt.plot(train_losses, label="Training loss")
plt.plot(validation_losses, label="Validation loss")
plt.xlabel("Epoch")
plt.ylabel("Loss (cross entropy)")
plt.legend()
    # Save before showing, so an interactive backend does not clear the figure first
    plt.savefig("losses.png")
    if show:
        plt.show()
def load_data(path: str) -> bytes:
with open(path, "rb") as f:
return f.read()
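
A sketch combining the helpers (the file path is hypothetical): load raw bytes, build (context, next-byte) pairs, and wrap them in a DataLoader:

from torch.utils.data import DataLoader

raw = load_data("data/sample.bin")                  # hypothetical path
pairs = make_context_pairs(raw, context_length=128)
loader = DataLoader(pairs, batch_size=64, shuffle=True)
x, y = next(iter(loader))                           # x: (64, 128), y: (64,)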