chore: Restructure
This commit is contained in:
parent
8b6c4e17ab
commit
f32f4678e1
62 changed files with 0 additions and 10547 deletions
115
src/dataset_loaders/Dataset.py
Normal file
115
src/dataset_loaders/Dataset.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
from abc import abstractmethod, ABC
|
||||
from os.path import join, curdir
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
from tqdm import tqdm
|
||||
|
||||
"""
|
||||
Author: Tibo De Peuter
|
||||
"""
|
||||
|
||||
|
||||
class Dataset(TorchDataset, ABC):
|
||||
"""Abstract base class for datasets."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self,
|
||||
name: str,
|
||||
root: str | None,
|
||||
split: str = 'train',
|
||||
transform: Callable = None,
|
||||
size: int = -1
|
||||
):
|
||||
"""
|
||||
:param root: Path to the dataset root directory
|
||||
:param split: The dataset split, e.g. 'train', 'validation', 'test'
|
||||
:param size: Override the maximum size of the dataset, useful for debugging
|
||||
"""
|
||||
if root is None:
|
||||
root = join(curdir, 'data')
|
||||
|
||||
self._root = join(root, name)
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.size = size
|
||||
self.data = None
|
||||
|
||||
self.chunk_offsets: list[int] = []
|
||||
self.bytes: bytes = bytes()
|
||||
self.tensor: Tensor = torch.tensor([])
|
||||
|
||||
@property
|
||||
def root(self):
|
||||
return self._root
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def process_data(self):
|
||||
if self.size == -1:
|
||||
# Just use the whole dataset
|
||||
self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
|
||||
else:
|
||||
# Use only partition, calculate offsets
|
||||
self.chunk_offsets = self.get_offsets()
|
||||
self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
|
||||
|
||||
self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
|
||||
|
||||
def get_offsets(self):
|
||||
"""
|
||||
Calculate for each chunk how many bytes came before it
|
||||
"""
|
||||
offsets = [0]
|
||||
while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
|
||||
idx = len(offsets) - 1
|
||||
offsets.append(offsets[idx] + len(self.data[idx]))
|
||||
print(offsets)
|
||||
return offsets
|
||||
|
||||
def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
|
||||
item = ''
|
||||
|
||||
# Determine first chunk in which item is located
|
||||
chunk_idx = 0
|
||||
while idx >= offsets[chunk_idx]:
|
||||
chunk_idx += 1
|
||||
chunk_idx -= 1
|
||||
|
||||
# Extract item from chunks
|
||||
chunk = str(self.data[chunk_idx])
|
||||
chunk_start = offsets[chunk_idx]
|
||||
|
||||
chunk_item_start = idx - chunk_start
|
||||
item_len_remaining = context_length + 1
|
||||
|
||||
assert len(item) + item_len_remaining == context_length + 1
|
||||
|
||||
while chunk_item_start + item_len_remaining > len(chunk):
|
||||
adding_now_len = len(chunk) - chunk_item_start
|
||||
item += chunk[chunk_item_start:]
|
||||
|
||||
chunk_idx += 1
|
||||
chunk = str(self.data[chunk_idx])
|
||||
|
||||
chunk_item_start = 0
|
||||
item_len_remaining -= adding_now_len
|
||||
|
||||
assert len(item) + item_len_remaining == context_length + 1
|
||||
|
||||
item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]
|
||||
|
||||
assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"
|
||||
|
||||
# Transform to tensor
|
||||
data = ''.join(item).encode('utf-8', errors='replace')
|
||||
t = torch.tensor(list(data), dtype=torch.long)
|
||||
x, y = t[:-1], t[-1]
|
||||
|
||||
if self.transform:
|
||||
x = self.transform(x)
|
||||
|
||||
return x, y
|
||||
59
src/dataset_loaders/EnWik9.py
Normal file
59
src/dataset_loaders/EnWik9.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
from math import ceil
|
||||
from typing import Callable
|
||||
|
||||
from datasets import load_dataset, Features, Value
|
||||
|
||||
from .Dataset import Dataset
|
||||
|
||||
|
||||
class EnWik9DataSet(Dataset):
|
||||
"""
|
||||
Hugging Face: https://huggingface.co/datasets/haukur/enwik9
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
root: str | None = None,
|
||||
split: str = 'train',
|
||||
transform: Callable | None = None,
|
||||
size: int = -1
|
||||
):
|
||||
super().__init__('enwik9', root, split, transform, size)
|
||||
|
||||
print(f"Loading from HuggingFace")
|
||||
ft = Features({'text': Value('string')})
|
||||
# Don't pass split here, dataset only contains training
|
||||
text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
|
||||
self.data = text_chunks['text']
|
||||
self.size = size
|
||||
|
||||
# Model uses fixed 128-length context
|
||||
self.context_length = 128
|
||||
|
||||
self.process_data()
|
||||
|
||||
# Define splits manually, because they do not exist in the dataset
|
||||
split_point = ceil(self.chunk_offsets[-1] * 0.8)
|
||||
|
||||
if self.split == 'train':
|
||||
self.start_byte = 0
|
||||
self.end_byte = split_point
|
||||
elif self.split == 'validation':
|
||||
self.start_byte = split_point
|
||||
self.end_byte = self.chunk_offsets[-1]
|
||||
else:
|
||||
raise ValueError("split must be 'train' or 'validation'")
|
||||
|
||||
print("Done initializing dataset")
|
||||
|
||||
def __len__(self):
|
||||
return self.end_byte - self.start_byte - self.context_length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
|
||||
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
|
||||
y = self.tensor[self.start_byte + idx + self.context_length]
|
||||
|
||||
if self.transform:
|
||||
x = self.transform(x)
|
||||
|
||||
return x, y
|
||||
63
src/dataset_loaders/LoremIpsumDataset.py
Normal file
63
src/dataset_loaders/LoremIpsumDataset.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from math import ceil
|
||||
from typing import Callable
|
||||
|
||||
from lorem.text import TextLorem
|
||||
from tqdm import tqdm
|
||||
|
||||
from .Dataset import Dataset
|
||||
|
||||
|
||||
class LoremIpsumDataset(Dataset):
|
||||
def __init__(self,
|
||||
root: str | None = None,
|
||||
split: str = 'train',
|
||||
transform: Callable = None,
|
||||
size: int = 2**30
|
||||
):
|
||||
super().__init__('lorem_ipsum', root, split, transform, size)
|
||||
|
||||
_lorem = TextLorem()
|
||||
|
||||
self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
|
||||
self.size = size
|
||||
|
||||
self.context_length = 128
|
||||
|
||||
self.process_data()
|
||||
|
||||
split_point = ceil(self.chunk_offsets[-1] * 0.8)
|
||||
|
||||
if self.split == 'train':
|
||||
self.start_byte = 0
|
||||
self.end_byte = split_point
|
||||
elif self.split == 'validation':
|
||||
self.start_byte = split_point
|
||||
self.end_byte = self.chunk_offsets[-1]
|
||||
else:
|
||||
raise ValueError("split must be 'train' or 'validation'")
|
||||
|
||||
print("Done initializing dataset")
|
||||
|
||||
def __len__(self):
|
||||
return self.end_byte - self.start_byte - self.context_length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Get sequence of characters
|
||||
# x_str = self.text[idx: idx + self.context_length]
|
||||
# y_char = self.text[idx + self.context_length]
|
||||
#
|
||||
# # Convert to tensors
|
||||
# x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
|
||||
# y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
|
||||
#
|
||||
# if self.transform is not None:
|
||||
# x = self.transform(x)
|
||||
#
|
||||
# return x, y
|
||||
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
|
||||
y = self.tensor[self.start_byte + idx + self.context_length]
|
||||
|
||||
if self.transform:
|
||||
x = self.transform(x)
|
||||
|
||||
return x, y
|
||||
51
src/dataset_loaders/OpenGenomeDataset.py
Normal file
51
src/dataset_loaders/OpenGenomeDataset.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from typing import Callable
|
||||
|
||||
from datasets import load_dataset, Value, Features
|
||||
|
||||
from .Dataset import Dataset
|
||||
|
||||
|
||||
class OpenGenomeDataset(Dataset):
|
||||
"""
|
||||
Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome
|
||||
|
||||
:param split Either 'train', 'test' or 'validation'
|
||||
:param stage Either 'sample', 'stage1' or 'stage2'.
|
||||
'sample' only provides a 'validation' split
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
root: str | None = None,
|
||||
split: str = 'train',
|
||||
transform: Callable = None,
|
||||
size: int = -1,
|
||||
stage: str = 'stage2'
|
||||
):
|
||||
super().__init__('open_genome', root, split, transform, size)
|
||||
|
||||
print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
|
||||
ft = Features({'text': Value('string')})
|
||||
data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
|
||||
self.data = data['text']
|
||||
self.size = size
|
||||
|
||||
# Model uses fixed 128-length context
|
||||
self.context_length = 128
|
||||
|
||||
self.process_data()
|
||||
|
||||
print("Done initializing dataset")
|
||||
|
||||
def __len__(self):
|
||||
# return len(self.data) - self.context_length
|
||||
return self.chunk_offsets[-1] - self.context_length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
|
||||
x = self.tensor[idx:idx + self.context_length]
|
||||
y = self.tensor[idx + self.context_length]
|
||||
|
||||
if self.transform:
|
||||
x = self.transform(x)
|
||||
|
||||
return x, y
|
||||
10
src/dataset_loaders/__init__.py
Normal file
10
src/dataset_loaders/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from .Dataset import Dataset
|
||||
from .EnWik9 import EnWik9DataSet
|
||||
from .LoremIpsumDataset import LoremIpsumDataset
|
||||
from .OpenGenomeDataset import OpenGenomeDataset
|
||||
|
||||
dataset_called: dict[str, type[Dataset]] = {
|
||||
'enwik9': EnWik9DataSet,
|
||||
'lorem_ipsum': LoremIpsumDataset,
|
||||
'opengenome': OpenGenomeDataset
|
||||
}
|
||||
1
src/models/__init__.py
Normal file
1
src/models/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .cnn import CNNPredictor
|
||||
1
src/models/cnn/__init__.py
Normal file
1
src/models/cnn/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .cnn import CNNPredictor
|
||||
52
src/models/cnn/cnn.py
Normal file
52
src/models/cnn/cnn.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class CNNPredictor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256,
|
||||
embed_dim=64,
|
||||
hidden_dim=128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Embedding: maps bytes (0–255) → vectors
|
||||
self.embed = nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# 2. Convolutional feature extractor
|
||||
self.conv_layers = nn.Sequential(
|
||||
nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# 3. Global pooling to collapse sequence length
|
||||
self.pool = nn.AdaptiveAvgPool1d(1) # → (B, hidden_channels, 1)
|
||||
|
||||
# 4. Final classifier
|
||||
self.fc = nn.Linear(hidden_dim, vocab_size) # → (B, 256)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: LongTensor of shape (B, 128), values 0-255
|
||||
"""
|
||||
# embed: (B, 128, embed_dim)
|
||||
x = self.embed(x)
|
||||
|
||||
# conv1d expects (B, C_in, L) → swap dims
|
||||
x = x.transpose(1, 2) # (B, embed_dim, 128)
|
||||
|
||||
# apply CNN
|
||||
x = self.conv_layers(x) # (B, hidden_channels, 128)
|
||||
|
||||
# global average pooling over sequence
|
||||
x = self.pool(x).squeeze(-1) # (B, hidden_channels)
|
||||
|
||||
# final classifier
|
||||
logits = self.fc(x) # (B, 256)
|
||||
return logits
|
||||
|
||||
|
||||
26
src/trainers/FullTrainer.py
Normal file
26
src/trainers/FullTrainer.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .trainer import Trainer
|
||||
from .train import train
|
||||
from ..utils import print_losses
|
||||
|
||||
class FullTrainer(Trainer):
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
if model is None:
|
||||
raise ValueError("Model must be provided: run optuna optimizations first")
|
||||
|
||||
model.to(device)
|
||||
train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
|
||||
print_losses(train_loss, val_loss)
|
||||
62
src/trainers/OptunaTrainer.py
Normal file
62
src/trainers/OptunaTrainer.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from typing import Callable
|
||||
|
||||
import optuna
|
||||
import optuna.trial as tr
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .trainer import Trainer
|
||||
from ..models.cnn import CNNPredictor
|
||||
from .train import train
|
||||
|
||||
|
||||
def create_model(trial: tr.Trial, vocab_size: int = 256):
|
||||
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
|
||||
embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
|
||||
|
||||
return CNNPredictor(
|
||||
vocab_size=vocab_size,
|
||||
hidden_dim=hidden_dim,
|
||||
embed_dim=embedding_dim,
|
||||
)
|
||||
|
||||
|
||||
def objective_function(
|
||||
trial: tr.Trial,
|
||||
training_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
device: str
|
||||
):
|
||||
model = create_model(trial).to(device)
|
||||
_, validation_loss = train(model, training_loader, validation_loader, loss_fn)
|
||||
return min(validation_loss)
|
||||
|
||||
|
||||
class OptunaTrainer(Trainer):
|
||||
def __init__(self, n_trials: int | None = None):
|
||||
super().__init__()
|
||||
self.n_trials = n_trials if n_trials is not None else 20
|
||||
print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
|
||||
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
study = optuna.create_study(study_name="CNN network", direction="minimize")
|
||||
study.optimize(
|
||||
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
|
||||
n_trials=self.n_trials
|
||||
)
|
||||
|
||||
best_params = study.best_trial.params
|
||||
best_model = CNNPredictor(
|
||||
**best_params
|
||||
)
|
||||
torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")
|
||||
3
src/trainers/__init__.py
Normal file
3
src/trainers/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .OptunaTrainer import OptunaTrainer
|
||||
from .FullTrainer import FullTrainer
|
||||
from .trainer import Trainer
|
||||
59
src/trainers/train.py
Normal file
59
src/trainers/train.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm import tqdm
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def train(
|
||||
model: nn.Module,
|
||||
training_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
epochs: int = 100,
|
||||
learning_rate: float = 1e-3,
|
||||
weight_decay: float = 1e-8,
|
||||
device="cuda"
|
||||
) -> tuple[list[float], list[float]]:
|
||||
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||
|
||||
avg_training_losses = []
|
||||
avg_validation_losses = []
|
||||
|
||||
for epoch in range(epochs):
|
||||
|
||||
model.train()
|
||||
total_loss = []
|
||||
|
||||
for x, y in tqdm(training_loader):
|
||||
x = x.long().to(device) # important for Embedding
|
||||
y = y.long().to(device) # must be (B,) for CE
|
||||
|
||||
optimizer.zero_grad()
|
||||
logits = model(x) # (B, 256)
|
||||
loss = loss_fn(logits, y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss.append(loss.item())
|
||||
|
||||
avg_training_losses.append(sum(total_loss) / len(total_loss))
|
||||
|
||||
# ----- validation -----
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
losses = []
|
||||
for x, y in validation_loader:
|
||||
x = x.long().to(device)
|
||||
y = y.long().to(device)
|
||||
|
||||
logits = model(x)
|
||||
loss = loss_fn(logits, y)
|
||||
losses.append(loss.item())
|
||||
|
||||
avg_loss = sum(losses) / len(losses)
|
||||
avg_validation_losses.append(avg_loss)
|
||||
|
||||
return avg_training_losses, avg_validation_losses
|
||||
22
src/trainers/trainer.py
Normal file
22
src/trainers/trainer.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class Trainer(ABC):
|
||||
"""Abstract class for trainers."""
|
||||
|
||||
@abstractmethod
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
pass
|
||||
1
src/utils/__init__.py
Normal file
1
src/utils/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .utils import *
|
||||
31
src/utils/utils.py
Normal file
31
src/utils/utils.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import torch
|
||||
from torch.utils.data import TensorDataset
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
|
||||
data = torch.tensor(list(data), dtype=torch.long)
|
||||
sample_count = data.shape[0] - context_length
|
||||
x = data.unfold(0, context_length, 1)[:sample_count]
|
||||
y = data[context_length:]
|
||||
return TensorDataset(x, y)
|
||||
|
||||
def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
|
||||
plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
|
||||
plt.show()
|
||||
|
||||
def print_losses(train_losses: list[float], validation_losses: list[float], show=False):
|
||||
plt.plot(train_losses, label="Training loss")
|
||||
plt.plot(validation_losses, label="Validation loss")
|
||||
plt.xlabel("Epoch")
|
||||
plt.ylabel("Loss (cross entropy)")
|
||||
plt.legend()
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
plt.savefig("losses.png")
|
||||
|
||||
|
||||
def load_data(path: str) -> bytes:
|
||||
with open(path, "rb") as f:
|
||||
return f.read()
|
||||
Reference in a new issue