fix: conflicts

RobinMeersman 2025-12-13 11:27:46 +01:00
commit b178c097d8
90 changed files with 2034 additions and 11145 deletions

51
src/args.py Normal file

@ -0,0 +1,51 @@
from argparse import ArgumentParser
from src.dataset_loaders import dataset_called
def parse_arguments():
parser = ArgumentParser(prog="NeuralCompression")
parser.add_argument("--debug", "-d", action="store_true", required=False,
help="Enable debug mode: smaller datasets, more information")
parser.add_argument("--verbose", "-v", action="store_true", required=False,
help="Enable verbose mode")
parser.add_argument("--results", type=str, required=True, help="path to save graphs to")
parser.add_argument("--device", required=False, help="Override the device to use")
dataparser = ArgumentParser(add_help=False)
dataparser.add_argument("--data-root", type=str, required=False)
dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
dataparser.add_argument("--size", "-s", type=int, required=False,
help="Size of the subset of the dataset to use")
modelparser = ArgumentParser(add_help=False)
modelparser.add_argument("--model", "-m", type=str, required=False,
help="Which model to use")
modelparser.add_argument("--model-load-path", type=str, required=False,
help="Filepath to the model to load")
modelparser.add_argument("--model-save-path", type=str, required=False,
help="Filepath to the model to save")
modelparser.add_argument("--context", type=int, required=False,
help="Context length to use")
fileparser = ArgumentParser(add_help=False)
fileparser.add_argument("--input-file", "-i", required=False, type=str)
fileparser.add_argument("--output-file", "-o", required=False, type=str)
subparsers = parser.add_subparsers(dest="mode", required=True,
help="Mode to run in")
train_parser = subparsers.add_parser("train",
parents=[dataparser, modelparser],
help="Do a full training")
train_parser.add_argument("--method",
choices=["fetch", "optuna", "full"], required=True,
help="Method to use for training")
subparsers.add_parser("compress", parents=[modelparser, fileparser],
help="Compress a file")
subparsers.add_parser("decompress", parents=[modelparser, fileparser],
help="Decompress a file")
return parser.parse_args(), parser.print_help
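
A possible entry point for these arguments is sketched below. It is not part of this commit, and the main-module layout is an assumption, but the attribute names and the train()/compress() signatures match the files added here.

from src.args import parse_arguments
from src.utils import determine_device


def main():
    args, print_help = parse_arguments()
    device = args.device if args.device else determine_device()
    if args.mode == "train":
        from src.train import train
        train(device, dataset=args.dataset, data_root=args.data_root,
              size=args.size, context_length=args.context, method=args.method,
              model_name=args.model, model_path=args.model_load_path,
              model_out=args.model_save_path, results_dir=args.results)
    elif args.mode == "compress":
        from src.process import compress
        compress(device, model_path=args.model_load_path,
                 output_file=args.output_file, input_file=args.input_file)
    elif args.mode == "decompress":
        from src.process import decompress
        decompress()
    else:
        print_help()


if __name__ == "__main__":
    main()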


@ -0,0 +1,132 @@
from abc import abstractmethod, ABC
from itertools import accumulate
from os.path import join, curdir
from typing import Callable
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm
"""
Author: Tibo De Peuter
"""
class Dataset(TorchDataset, ABC):
"""Abstract base class for datasets."""
@abstractmethod
def __init__(self,
name: str,
root: str | None,
split: str = 'train',
transform: Callable = None,
size: int = -1,
context_length: int = 1024
):
"""
:param root: Path to the dataset root directory
:param split: The dataset split, e.g. 'train', 'validation', 'test'
:param size: Override the maximum size of the dataset, useful for debugging
"""
if root is None:
root = join(curdir, 'data')
self._root = join(root, name)
self.split = split
self.transform = transform
self.size = size
self.context_length = context_length
self.data = None
print(f"Context length: {self.context_length}")
self.chunk_offsets: list[int] = []
self.bytes: bytes = bytes()
self.tensor: Tensor = torch.tensor([])
@property
def root(self):
return self._root
    def __len__(self):
        # Number of raw chunks; subclasses override this with their sample counts
        return len(self.data)
def process_data(self):
self.chunk_offsets = self.get_offsets()
if self.size == -1:
# Just use the whole dataset
self.bytes = ''.join(tqdm(self.data, desc="Encoding data", leave=False)).encode('utf-8', errors='replace')
else:
            # Use only the partition covered by the offsets;
            # chunk_offsets has a leading 0, so it spans len(chunk_offsets) - 1 chunks
            self.bytes = (''.join(tqdm(self.data[:len(self.chunk_offsets) - 1], desc="Encoding data", leave=False))
.encode('utf-8', errors='replace'))
bytes_array = np.frombuffer(self.bytes, dtype=np.uint8) # Zero-copy
self.tensor = torch.from_numpy(bytes_array).to(torch.long, non_blocking=True)
def get_offsets(self):
"""
Calculate for each chunk how many bytes came before it
"""
data = self.data
size = self.size
if size == -1:
return [0, *accumulate(tqdm(map(len, data), desc="Calculating offsets", leave=False, total=len(data)))]
offsets = [0]
total = 0
append = offsets.append
for chunk in tqdm(data):
if total >= size:
break
total += len(chunk)
append(total)
return offsets
def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
item = ''
# Determine first chunk in which item is located
chunk_idx = 0
while idx >= offsets[chunk_idx]:
chunk_idx += 1
chunk_idx -= 1
# Extract item from chunks
chunk = str(self.data[chunk_idx])
chunk_start = offsets[chunk_idx]
chunk_item_start = idx - chunk_start
item_len_remaining = context_length + 1
assert len(item) + item_len_remaining == context_length + 1
while chunk_item_start + item_len_remaining > len(chunk):
adding_now_len = len(chunk) - chunk_item_start
item += chunk[chunk_item_start:]
chunk_idx += 1
chunk = str(self.data[chunk_idx])
chunk_item_start = 0
item_len_remaining -= adding_now_len
assert len(item) + item_len_remaining == context_length + 1
item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]
assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"
# Transform to tensor
        data = item.encode('utf-8', errors='replace')
t = torch.tensor(list(data), dtype=torch.long)
x, y = t[:-1], t[-1]
if self.transform:
x = self.transform(x)
return x, y
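
A minimal sketch (not in this commit) of how a concrete subclass is expected to use this base class; the ToyDataset name and its two text chunks are made up. process_data() concatenates the chunks, encodes them to bytes and exposes them as one long tensor, from which (context, next byte) pairs are sliced.

class ToyDataset(Dataset):
    def __init__(self, root=None, context_length=4):
        super().__init__('toy', root, split='train', context_length=context_length)
        self.data = ["hello ", "world"]  # chunks of text, as the loaders below provide them
        self.process_data()
    def __len__(self):
        return len(self.tensor) - self.context_length
    def __getitem__(self, idx):
        x = self.tensor[idx:idx + self.context_length]
        y = self.tensor[idx + self.context_length]
        return x, y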


@ -0,0 +1,57 @@
from math import ceil
from typing import Callable
from datasets import load_dataset, Features, Value
from .Dataset import Dataset
class EnWik9DataSet(Dataset):
"""
Hugging Face: https://huggingface.co/datasets/haukur/enwik9
"""
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable | None = None,
size: int = -1,
context_length: int = 1024
):
super().__init__('enwik9', root, split, transform, size, context_length)
print(f"Loading from HuggingFace")
ft = Features({'text': Value('string')})
# Don't pass split here, dataset only contains training
text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
self.data = text_chunks['text']
self.size = size
self.process_data()
# Define splits manually, because they do not exist in the dataset
split_point = ceil(self.chunk_offsets[-1] * 0.8)
if self.split == 'train':
self.start_byte = 0
self.end_byte = split_point
elif self.split == 'validation':
self.start_byte = split_point
self.end_byte = self.chunk_offsets[-1]
else:
raise ValueError("split must be 'train' or 'validation'")
print("Done initializing dataset")
def __len__(self):
return self.end_byte - self.start_byte - self.context_length
def __getitem__(self, idx):
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
y = self.tensor[self.start_byte + idx + self.context_length]
if self.transform:
x = self.transform(x)
return x, y


@ -0,0 +1,45 @@
from typing import Callable
from datasets import load_dataset
from .Dataset import Dataset
class HumanReferenceGenomeDataset(Dataset):
"""
Hugging Face: https://huggingface.co/datasets/InstaDeepAI/human_reference_genome
:param split: 'train' | 'validation' | 'test'
:param config: '6kbp' | '12kbp' (chunk length in the HF builder config)
"""
def __init__(self,
root: str | None = None,
split: str = "train",
transform: Callable = None,
size: int = -1,
context_length: int = 1024,
config: str = "6kbp",
):
super().__init__("human_reference_genome", root, split, transform, size, context_length)
print(f"Loading from HuggingFace (config: {config}, split: {split})")
data = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
cache_dir=self.root, trust_remote_code=True)
self.data = data["sequence"]
self.process_data()
print("Done initializing dataset")
def __len__(self):
return self.chunk_offsets[-1] - self.context_length
def __getitem__(self, idx):
x = self.tensor[idx: idx + self.context_length]
y = self.tensor[idx + self.context_length]
if self.transform:
x = self.transform(x)
return x, y


@ -0,0 +1,62 @@
from math import ceil
from typing import Callable
from lorem.text import TextLorem
from tqdm import tqdm
from .Dataset import Dataset
class LoremIpsumDataset(Dataset):
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable = None,
size: int = 2**30,
context_length: int = 1024
):
super().__init__('lorem_ipsum', root, split, transform, size, context_length)
_lorem = TextLorem()
self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
self.size = size
self.process_data()
split_point = ceil(self.chunk_offsets[-1] * 0.8)
if self.split == 'train':
self.start_byte = 0
self.end_byte = split_point
elif self.split == 'validation':
self.start_byte = split_point
self.end_byte = self.chunk_offsets[-1]
else:
raise ValueError("split must be 'train' or 'validation'")
print("Done initializing dataset")
def __len__(self):
return self.end_byte - self.start_byte - self.context_length
def __getitem__(self, idx):
# Get sequence of characters
# x_str = self.text[idx: idx + self.context_length]
# y_char = self.text[idx + self.context_length]
#
# # Convert to tensors
# x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
# y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
#
# if self.transform is not None:
# x = self.transform(x)
#
# return x, y
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
y = self.tensor[self.start_byte + idx + self.context_length]
if self.transform:
x = self.transform(x)
return x, y


@ -0,0 +1,49 @@
from typing import Callable
from datasets import load_dataset, Value, Features
from .Dataset import Dataset
class OpenGenomeDataset(Dataset):
"""
Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome
    :param split: Either 'train', 'test' or 'validation'
    :param stage: Either 'sample', 'stage1' or 'stage2'.
        'sample' only provides a 'validation' split
"""
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable = None,
size: int = -1,
context_length: int = 1024,
stage: str = 'stage2'
):
super().__init__('open_genome', root, split, transform, size, context_length)
print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
ft = Features({'text': Value('string')})
data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
self.data = data['text']
self.size = size
self.process_data()
print("Done initializing dataset")
def __len__(self):
# return len(self.data) - self.context_length
return self.chunk_offsets[-1] - self.context_length
def __getitem__(self, idx):
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
x = self.tensor[idx:idx + self.context_length]
y = self.tensor[idx + self.context_length]
if self.transform:
x = self.transform(x)
return x, y


@ -0,0 +1,12 @@
from .Dataset import Dataset
from .EnWik9 import EnWik9DataSet
from .HumanReferenceGenomeDataset import HumanReferenceGenomeDataset
from .LoremIpsumDataset import LoremIpsumDataset
from .OpenGenomeDataset import OpenGenomeDataset
dataset_called: dict[str, type[Dataset]] = {
'enwik9': EnWik9DataSet,
'lorem_ipsum': LoremIpsumDataset,
'opengenome': OpenGenomeDataset,
'humanreference': HumanReferenceGenomeDataset
}
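
A short usage sketch for this registry (the chosen dataset and sizes are only examples); the key comes from the --dataset argument defined in src/args.py.

from src.dataset_loaders import dataset_called

dataset_cls = dataset_called['enwik9']
train_split = dataset_cls(split='train', size=10_000, context_length=128)
x, y = train_split[0]  # x: LongTensor of 128 byte values, y: the byte that follows them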

14
src/models/Model.py Normal file

@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from torch import nn
class Model(nn.Module, ABC):
@abstractmethod
    def __init__(self, loss_function=None):
super().__init__()
self._loss_function = loss_function
@property
def loss_function(self):
return self._loss_function

8
src/models/__init__.py Normal file

@ -0,0 +1,8 @@
from .Model import Model
from .cnn import CNNPredictor
from .transformer import ByteTransformer
model_called: dict[str, type[Model]] = {
'cnn': CNNPredictor,
'transformer': ByteTransformer
}

18
src/models/autoencoder.py Normal file

@ -0,0 +1,18 @@
import torch
import torch.nn as nn
class Encoder(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Encoder, self).__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
pass
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
pass


@ -0,0 +1 @@
from .cnn import CNNPredictor

54
src/models/cnn/cnn.py Normal file

@ -0,0 +1,54 @@
import torch.nn as nn
from src.models import Model
class CNNPredictor(Model):
def __init__(
self,
vocab_size=256,
embed_dim=64,
hidden_dim=128,
):
super().__init__(nn.CrossEntropyLoss())
        # 1. Embedding: maps bytes (0-255) → vectors
self.embed = nn.Embedding(vocab_size, embed_dim)
# 2. Convolutional feature extractor
self.conv_layers = nn.Sequential(
nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
nn.ReLU(),
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
nn.ReLU(),
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
nn.ReLU(),
)
# 3. Global pooling to collapse sequence length
        self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_dim, 1)
# 4. Final classifier
self.fc = nn.Linear(hidden_dim, vocab_size) # → (B, 256)
def forward(self, x):
"""
x: LongTensor of shape (B, 128), values 0-255
"""
# embed: (B, 128, embed_dim)
x = self.embed(x)
# conv1d expects (B, C_in, L) → swap dims
x = x.transpose(1, 2) # (B, embed_dim, 128)
# apply CNN
        x = self.conv_layers(x)  # (B, hidden_dim, 128)
        # global average pooling over sequence
        x = self.pool(x).squeeze(-1)  # (B, hidden_dim)
# final classifier
logits = self.fc(x) # (B, 256)
return logits
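
An illustrative shape check with random bytes (not part of this commit):

import torch

model = CNNPredictor(vocab_size=256, embed_dim=64, hidden_dim=128)
x = torch.randint(0, 256, (8, 128))   # (B, context) batch of byte values
logits = model(x)                     # (8, 256) next-byte logits
loss = model.loss_function(logits, torch.randint(0, 256, (8,)))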


@ -0,0 +1 @@
from .transformer import ByteTransformer


@ -0,0 +1,70 @@
import torch.nn as nn
from torch import Tensor, arange
from src.models import Model
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)
    def forward(self, x):
        # x: [batch, seq, d_model], matching the batch-first data loaders
        seq_len = x.size(1)
        positions = arange(seq_len, device=x.device).unsqueeze(0)  # [1, seq]
        return x + self.pos_emb(positions)  # broadcast over batch
class ByteTransformer(Model):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
layer_norm_eps=1e-05,
max_len=128
):
        super().__init__(nn.CrossEntropyLoss())  # loss function is exposed via the Model base class
self.src_embedding = nn.Embedding(256, d_model)
self.tgt_embedding = nn.Embedding(256, d_model)
self.src_pos = LearnedPositionalEncoding(max_len, d_model)
self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)
self.transformer = nn.Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
            batch_first=True,  # the data loaders and training loop provide (B, L) batches
norm_first=False,
device=None,
dtype=None,
)
        self.output_proj = nn.Linear(d_model, 256)
def forward(
self,
src: Tensor,
tgt: Tensor,
) -> Tensor:
src_embeds = self.src_embedding(src)
tgt_embeds = self.tgt_embedding(tgt)
src_pos = self.src_pos(src_embeds)
tgt_pos = self.tgt_pos(tgt_embeds)
return self.output_proj(self.transformer(src_pos, tgt_pos))
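
An illustrative shape check, assuming the batch-first convention used by _forward in src/trainers/train.py, where a zero byte is prepended as the shifted decoder input:

import torch

model = ByteTransformer(d_model=128, nhead=4, max_len=128)
src = torch.randint(0, 256, (2, 128))  # (B, L) byte context
tgt_in = torch.cat([torch.zeros(2, 1, dtype=torch.long), src[:, :-1]], dim=1)
logits = model(src, tgt_in)            # (B, L, 256)
next_byte_logits = logits[:, -1, :]    # prediction at the last position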

30
src/process.py Normal file

@ -0,0 +1,30 @@
import torch
def compress(
device,
model_path: str,
output_file: str,
input_file: str | None = None
):
# Get input to compress
if input_file:
with open(input_file, "rb") as file:
byte_data = file.read()
else:
# Read from stdin
text = input()
byte_data = text.encode('utf-8', errors='replace')
tensor = torch.tensor(list(byte_data), dtype=torch.long)
print(tensor)
# Get model
model = torch.load(model_path, weights_only=False)
# TODO Feed to model for compression, store result
return
def decompress():
    raise NotImplementedError("Decompression is not implemented yet")

79
src/train.py Normal file

@ -0,0 +1,79 @@
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from src.dataset_loaders import dataset_called
from src.models import model_called
from src.trainers import OptunaTrainer, Trainer, FullTrainer
def train(
device,
dataset: str,
data_root: str,
n_trials: int | None = None,
size: int | None = None,
context_length: int | None = None,
method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None,
model_out: str | None = None,
results_dir: str = 'results'
):
batch_size = 64
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
    if model_name:
        print("Creating model")
        # model_called maps names to classes; instantiate with defaults here,
        # the OptunaTrainer rebuilds the model with tuned hyperparameters per trial
        model = model_called[model_name]()
else:
print("Loading model from disk")
model = torch.load(model_path, weights_only=False)
dataset_common_args = {
'root': data_root,
'transform': lambda x: x.to(device),
}
if size:
dataset_common_args['size'] = size
if context_length:
dataset_common_args['context_length'] = context_length
print("Loading in the dataset...")
if dataset in dataset_called:
training_set = dataset_called[dataset](split='train', **dataset_common_args)
validate_set = dataset_called[dataset](split='validation', **dataset_common_args)
else:
        # TODO Allow importing arbitrary files
        raise NotImplementedError("Importing external datasets is not implemented yet")
    if method == 'fetch':
        # TODO Move this earlier in the chain, because now everything is converted into tensors as well?
exit(0)
print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)
print("Training")
best_model = trainer.execute(
model=model,
train_loader=training_loader,
validation_loader=validation_loader,
n_epochs=n_trials,
device=device
)
print("Saving model...")
f = model_out or f"saved_models/{model.__class__.__name__}.pt"
# Make sure path exists
Path(f).parent.mkdir(parents=True, exist_ok=True)
torch.save(best_model, f)
print(f"Saved model to '{f}'")


@ -0,0 +1,29 @@
from torch import nn
from torch.utils.data import DataLoader
from .train import train
from .trainer import Trainer
from ..models import Model
from ..utils import print_losses
class FullTrainer(Trainer):
def __init__(self, results_dir: str = 'results'):
self.results_dir = results_dir
def execute(
self,
model: Model,
train_loader: DataLoader,
validation_loader: DataLoader,
n_epochs: int | None,
device: str
) -> nn.Module:
if model is None:
raise ValueError("Model must be provided: run optuna optimizations first")
model.to(device)
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,
device=device)
print_losses(train_loss, val_loss, filename=f"{self.results_dir}/{model.__class__.__name__}-losses.png")
return model


@ -0,0 +1,72 @@
import optuna
import optuna.trial as tr
from torch import nn
from torch.utils.data import DataLoader
from .train import train
from .trainer import Trainer
from ..models import Model, CNNPredictor, ByteTransformer
def create_model(trial: tr.Trial, model: Model | type[Model]):
    # `model` is either a Model subclass (picked by name) or a loaded instance;
    # resolve it to its class so a fresh model can be built for every trial.
    model_cls = model if isinstance(model, type) else model.__class__
    if model_cls is CNNPredictor:
        return model_cls(
            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
            vocab_size=256,
        )
    if model_cls is ByteTransformer:
        nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
        # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
        return model_cls(
            d_model=128,  # hard coded for now as data loaders provide fixed (B, 128) tensors
            nhead=nhead,
            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
        )
    return None
def objective_function(
trial: tr.Trial,
training_loader: DataLoader,
validation_loader: DataLoader,
model: Model,
device: str
):
model = create_model(trial, model).to(device)
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
return min(validation_loss)
class OptunaTrainer(Trainer):
def __init__(self, n_trials: int | None = None):
super().__init__()
self.n_trials = n_trials if n_trials else 20
print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
def execute(
self,
model: Model,
train_loader: DataLoader,
validation_loader: DataLoader,
n_epochs: int,
device: str
) -> nn.Module:
study = optuna.create_study(direction="minimize")
study.optimize(
lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
n_trials=self.n_trials
)
        best_params = study.best_trial.params
        # Rebuild the best configuration through create_model with a FixedTrial so the
        # non-tuned arguments (vocab_size, d_model) are filled in the same way as during the search
        best_model = create_model(tr.FixedTrial(best_params), model)
return best_model

3
src/trainers/__init__.py Normal file

@ -0,0 +1,3 @@
from .OptunaTrainer import OptunaTrainer
from .FullTrainer import FullTrainer
from .trainer import Trainer

81
src/trainers/train.py Normal file

@ -0,0 +1,81 @@
from typing import Callable
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from ..models import ByteTransformer, Model
def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
if isinstance(model, ByteTransformer):
tgt_in = torch.cat([
torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
x[:, :-1]
], dim=1)
logits = model(x, tgt_in)
# only consider the last time step of the model where the full context
# is available
return logits[:, -1, :]
return model(x)
def train(
model: Model,
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable,
epochs: int | None = None,
learning_rate: float = 1e-3,
weight_decay: float = 1e-8,
device="cuda"
) -> tuple[list[float], list[float]]:
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
avg_training_losses = []
avg_validation_losses = []
if epochs is None:
epochs = 100
for epoch in range(epochs):
model.train()
total_loss = []
for x, y in tqdm(training_loader):
# size (B, 128)
x = x.long().to(device)
# size (B)
y = y.long().to(device)
optimizer.zero_grad()
logits = _forward(model, x, device)
loss = loss_fn(logits, y)
loss.backward()
optimizer.step()
total_loss.append(loss.item())
avg_training_losses.append(sum(total_loss) / len(total_loss))
# ----- validation -----
model.eval()
with torch.no_grad():
losses = []
for x, y in validation_loader:
x = x.long().to(device)
y = y.long().to(device)
logits = _forward(model, x, device)
loss = loss_fn(logits, y)
losses.append(loss.item())
avg_loss = sum(losses) / len(losses)
avg_validation_losses.append(avg_loss)
return avg_training_losses, avg_validation_losses
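
A standalone sketch of this loop, using the train() defined above and building (context, next byte) pairs with make_context_pairs from src/utils instead of one of the dataset classes; the byte string and sizes are illustrative.

from torch.utils.data import DataLoader
from src.models import CNNPredictor
from src.utils import make_context_pairs

raw = b"some raw bytes to model " * 200
pairs = make_context_pairs(raw, context_length=128)
loader = DataLoader(pairs, batch_size=32, shuffle=True)
model = CNNPredictor()
train_losses, val_losses = train(model, loader, loader, model.loss_function,
                                 epochs=2, device="cpu")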

19
src/trainers/trainer.py Normal file

@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
import torch.nn as nn
from torch.utils.data import DataLoader
class Trainer(ABC):
"""Abstract class for trainers."""
@abstractmethod
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
n_epochs: int | None,
device: str
) -> nn.Module:
pass

1
src/utils/__init__.py Normal file

@ -0,0 +1 @@
from .utils import *

175
src/utils/benchmark.py Normal file

@ -0,0 +1,175 @@
"""Utilities functions for benchmarking."""
import json
import string
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from os import getpid, path
from pathlib import Path
from random import choices
from subprocess import DEVNULL, PIPE, CalledProcessError, TimeoutExpired, run
from timeit import timeit
from typing import Callable
from memray import Tracker
from ..utils.benchmark_dataclasses import BenchmarkItem, BenchmarkResult
log = getLogger(__name__)
def get_commit_hash() -> str:
"""
Get the commit hash of the current git repository.
If not working in a git repository, return a random string that looks like a commit hash.
"""
try:
return run(
["git", "rev-parse", "--short", "HEAD"],
check=True,
stdout=PIPE,
stderr=DEVNULL,
text=True,
).stdout.strip()
except CalledProcessError as e:
log.error(
"Could not determine the commit hash. Are you using a git repository?:\n%s",
e,
)
log.error("Using a random string as commit hash.")
return "".join(choices(string.hexdigits[:-6], k=40))
def init_stat_file(stat_file: Path, header: str) -> int:
"""Initialize a statistics file with a header."""
# Check if the parent directory exists
stat_file.parent.mkdir(parents=True, exist_ok=True)
# Check if the file exists
if stat_file.exists():
# Nothing left to do
return 0
# Initialize the file by writing the header to it.
log.debug("Initializing statistics file %s", stat_file)
stat_file.touch()
stat_file.write_text(f"{header}\n", encoding="utf-8")
return 1
def track_time_memory(task: Callable, result: BenchmarkResult, mem_file: Path, mem_json_file: Path):
"""Track the time and memory consumption of a task."""
def task_with_result():
result.value = task()
# Measure memory consumption
with Tracker(file_name=mem_file, native_traces=True, follow_fork=True, memory_interval_ms=1):
try:
# Measure runtime
result.runtime = timeit(task_with_result, number=1, globals=globals())
except BaseException as e:
log.error("Error while timing the program:\n%s", e, exc_info=True)
return None
# Convert binary memory file into JSON.
try:
run(
[
"python",
"-m",
"memray",
"stats",
"--json",
"--num-largest",
"1",
"--output",
mem_json_file,
mem_file,
],
check=True,
timeout=100,
stdout=DEVNULL,
)
# Parse JSON to get peak_memory
mem_results = json.loads(mem_json_file.read_text(encoding="utf-8"))
result.peak_memory = mem_results["metadata"]["peak_memory"]
except CalledProcessError as e:
log.error(
"Something went wrong while processing the memray memory file %s:\n%s",
mem_file,
e,
)
except TimeoutExpired as e:
log.error(
"Timeout expired while processing the memray memory file %s:\n%s}",
mem_file,
e,
)
return result
def execute_benchmark(
benchmark_item: BenchmarkItem,
results_dir: str | Path,
timeout: int = 100,
) -> BenchmarkResult:
"""Execute a benchmark and track its runtime and peak memory consumption."""
mem_file = Path(path.join(results_dir, f"memray-{benchmark_item.task.__name__}.mem"))
mem_json_file = Path(path.join(results_dir, f"memray-{benchmark_item.task.__name__}.json"))
result = BenchmarkResult(benchmark_item)
try:
# Time and track memory usage
# Kill after timeout in seconds
with ThreadPoolExecutor() as executor:
future = executor.submit(
lambda: track_time_memory(
lambda: benchmark_item.task(**benchmark_item.arguments), result, mem_file, mem_json_file
)
)
executed_result = future.result(timeout=timeout)
if executed_result is not None:
result = executed_result
log.info(
"PID %d: %s finished [%.6f seconds, %d bytes]",
getpid(),
benchmark_item.get_method(),
result.runtime,
result.peak_memory,
)
except TimeoutError:
log.error("Timeout expired while running the benchmark_suite, cleaning up now.")
log.info(
"PID %d: %s failed after timeout (%d seconds)",
getpid(),
benchmark_item.get_method(),
timeout,
)
finally:
# Clean up memory dump file to save disk space.
        mem_file.unlink(missing_ok=True)
return result
if __name__ == "__main__":
import hydra
# Dummy example, read the contents of the dataset
def _read_contents(filename):
with open(filename, encoding="utf-8") as f:
log.info("Dataset content: %s", f.read())
def _read_contents_wrapper(cfg):
return _read_contents(cfg.dataset.path)
hydra_wrapped = hydra.main(config_path="../../config", config_name="config", version_base="1.2")(
_read_contents_wrapper
)()


@ -0,0 +1,79 @@
"""
Benchmark data classes.
This module contains the BenchmarkItem and BenchmarkResult classes, which are used to store and
print the results of a benchmark run.
"""
from dataclasses import dataclass
from typing import Any, Callable
@dataclass(init=True)
class BenchmarkItem:
"""A class used to represent a benchmark_suite (iteration)."""
task: Callable
arguments: dict
def __str__(self) -> str:
"""String representation of the BenchmarkItem object."""
return self.get_in_data_format()
def get_method(self) -> str:
"""
Format the method as if it were a function call.
"""
method_name = self.task.__name__
arguments = ", ".join(
f'{key}={str(value)[:15]}'
for key, value in self.arguments.items()
)
return f"{method_name}({arguments})"
def get_in_data_format(self) -> str:
"""
        Format the benchmark item to be printed to a .dat file.
"""
# Flatten out arguments
values = list(self.__dict__.values())
values[1:2] = values[1].values()
return " ".join(map(str, values))
def get_header(self) -> str:
"""
Returns the header which is just the names of the fields separated by spaces.
"""
return " ".join(self.__dict__.keys())
@dataclass(init=True)
class BenchmarkResult:
"""A class used to represent the result of a benchmark_suite."""
benchmark_item: BenchmarkItem
runtime: float = 0
peak_memory: int = 0
value: Any = None
def __str__(self) -> str:
"""String representation of the BenchmarkResult object."""
return self.get_in_data_format()
def get_in_data_format(self) -> str:
"""
        Format the benchmark result to be printed to a .dat file.
"""
return " ".join(map(str, self.__dict__.values()))
def get_header(self) -> str:
"""
Returns the header which is just the names of the fields separated by spaces.
"""
# Get header of the BenchmarkItem
keys = list(self.__annotations__.keys())
keys[0:1] = self.benchmark_item.__annotations__.keys()
keys[1:2] = self.benchmark_item.arguments.keys()
return " ".join(keys)

64
src/utils/utils.py Normal file

@ -0,0 +1,64 @@
import csv
from os import path
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset
def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
data = torch.tensor(list(data), dtype=torch.long)
sample_count = data.shape[0] - context_length
x = data.unfold(0, context_length, 1)[:sample_count]
y = data[context_length:]
return TensorDataset(x, y)
def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
plt.show()
def print_losses(train_losses: list[float], validation_losses: list[float], filename: str | None = None, show=False):
plt.plot(train_losses, label="Training loss")
plt.plot(validation_losses, label="Validation loss")
plt.xlabel("Epoch")
plt.ylabel("Loss (cross entropy)")
plt.legend()
if show:
plt.show()
if filename is None:
filename = path.join("results", "losses.png")
print(f"Saving losses to {filename}...")
plt.savefig(filename)
# Also write to CSV file
with open(filename.replace(".png", ".csv"), "w") as f:
writer = csv.writer(f)
writer.writerow(["epoch", "train_loss", "validation_loss"])
for i in range(len(train_losses)):
writer.writerow([i, train_losses[i], validation_losses[i]])
print("Done")
def determine_device():
# NVIDIA GPUs (most HPC clusters)
if torch.cuda.is_available():
return torch.device("cuda")
# Apple Silicon (macOS)
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return torch.device("mps")
# Intel GPUs (oneAPI)
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.device("xpu")
else:
return torch.device("cpu")
def load_data(path: str) -> bytes:
with open(path, "rb") as f:
return f.read()
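
A small usage sketch for these helpers (the probability values are placeholders):

device = determine_device()
print(f"Using device: {device}")
# Plot a uniform next-byte distribution over all 256 byte values
print_distribution((0, 256), [1 / 256] * 256)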