chore: Restructure

Tibo De Peuter 2025-12-05 12:37:48 +01:00
parent 8b6c4e17ab
commit f32f4678e1
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
62 changed files with 0 additions and 10547 deletions


@@ -0,0 +1,115 @@
from abc import abstractmethod, ABC
from os.path import join, curdir
from typing import Callable

import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm

"""
Author: Tibo De Peuter
"""


class Dataset(TorchDataset, ABC):
    """Abstract base class for datasets."""

    @abstractmethod
    def __init__(self,
                 name: str,
                 root: str | None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1
                 ):
        """
        :param root: Path to the dataset root directory
        :param split: The dataset split, e.g. 'train', 'validation', 'test'
        :param size: Override the maximum size of the dataset, useful for debugging
        """
        if root is None:
            root = join(curdir, 'data')
        self._root = join(root, name)

        self.split = split
        self.transform = transform
        self.size = size

        self.data = None
        self.chunk_offsets: list[int] = []
        self.bytes: bytes = bytes()
        self.tensor: Tensor = torch.tensor([])

    @property
    def root(self):
        return self._root

    def __len__(self):
        return len(self.data)
    def process_data(self):
        # Always compute the chunk offsets; subclasses use them to place split boundaries
        self.chunk_offsets = self.get_offsets()

        if self.size == -1:
            # Just use the whole dataset
            self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
        else:
            # Use only a partition of the chunks
            self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')

        self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
    def get_offsets(self):
        """
        Calculate for each chunk how many bytes came before it
        """
        offsets = [0]
        while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
            idx = len(offsets) - 1
            offsets.append(offsets[idx] + len(self.data[idx]))
        return offsets
    def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
        item = ''

        # Determine the first chunk in which the item is located
        chunk_idx = 0
        while idx >= offsets[chunk_idx]:
            chunk_idx += 1
        chunk_idx -= 1

        # Extract the item from the chunks, crossing chunk boundaries where needed
        chunk = str(self.data[chunk_idx])
        chunk_start = offsets[chunk_idx]
        chunk_item_start = idx - chunk_start
        item_len_remaining = context_length + 1
        assert len(item) + item_len_remaining == context_length + 1

        while chunk_item_start + item_len_remaining > len(chunk):
            adding_now_len = len(chunk) - chunk_item_start
            item += chunk[chunk_item_start:]

            chunk_idx += 1
            chunk = str(self.data[chunk_idx])
            chunk_item_start = 0
            item_len_remaining -= adding_now_len
            assert len(item) + item_len_remaining == context_length + 1

        item += chunk[chunk_item_start:chunk_item_start + item_len_remaining]
        assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"

        # Transform to tensor: the first context_length bytes are the input, the last byte is the target
        data = item.encode('utf-8', errors='replace')
        t = torch.tensor(list(data), dtype=torch.long)
        x, y = t[:-1], t[-1]

        if self.transform:
            x = self.transform(x)

        return x, y
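
For illustration only, here is a minimal sketch of how a concrete subclass is expected to use this base class; the TinyTextDataset name, its hard-coded chunks, and the context length of 8 are hypothetical and not part of this commit:

from typing import Callable

from .Dataset import Dataset


class TinyTextDataset(Dataset):
    """Hypothetical example: a tiny in-memory dataset built on the abstract base class."""

    def __init__(self, root: str | None = None, split: str = 'train',
                 transform: Callable | None = None, size: int = -1):
        super().__init__('tiny_text', root, split, transform, size)

        # Subclasses are expected to fill self.data with string chunks...
        self.data = ["hello world, ", "this is a tiny dataset. ", "bytes become tokens."]
        self.context_length = 8

        # ...and then call process_data() to build self.bytes and self.tensor
        self.process_data()

    def __len__(self):
        return len(self.tensor) - self.context_length

    def __getitem__(self, idx):
        x = self.tensor[idx:idx + self.context_length]
        y = self.tensor[idx + self.context_length]
        if self.transform:
            x = self.transform(x)
        return x, y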


@@ -0,0 +1,59 @@
from math import ceil
from typing import Callable

from datasets import load_dataset, Features, Value

from .Dataset import Dataset


class EnWik9DataSet(Dataset):
    """
    Hugging Face: https://huggingface.co/datasets/haukur/enwik9
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1
                 ):
        super().__init__('enwik9', root, split, transform, size)

        print("Loading from HuggingFace")
        ft = Features({'text': Value('string')})
        # Don't pass the requested split here, the dataset only contains a train split
        text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
        self.data = text_chunks['text']
        self.size = size

        # Model uses a fixed 128-byte context
        self.context_length = 128

        self.process_data()

        # Define splits manually, because they do not exist in the dataset
        split_point = ceil(self.chunk_offsets[-1] * 0.8)
        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")

        print("Done initializing dataset")

    def __len__(self):
        return self.end_byte - self.start_byte - self.context_length

    def __getitem__(self, idx):
        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
        y = self.tensor[self.start_byte + idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
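
As a usage sketch, the class plugs straight into a standard PyTorch DataLoader; the batch size and the size cap below are arbitrary choices, not values from this commit:

from torch.utils.data import DataLoader

from .EnWik9 import EnWik9DataSet

# size roughly caps how much of enwik9 is used, which is handy for debugging
train_set = EnWik9DataSet(split='train', size=10_000_000)
val_set = EnWik9DataSet(split='validation', size=10_000_000)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
for x, y in train_loader:
    # x: (64, 128) byte ids, y: (64,) next-byte targets
    break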


@@ -0,0 +1,63 @@
from math import ceil
from typing import Callable

from lorem.text import TextLorem
from tqdm import tqdm

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = 2**30
                 ):
        super().__init__('lorem_ipsum', root, split, transform, size)

        _lorem = TextLorem()
        self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
        self.size = size

        self.context_length = 128

        self.process_data()

        # Define splits manually: the first 80% of the bytes are train, the rest validation
        split_point = ceil(self.chunk_offsets[-1] * 0.8)
        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")

        print("Done initializing dataset")

    def __len__(self):
        return self.end_byte - self.start_byte - self.context_length

    def __getitem__(self, idx):
        # Get sequence of characters
        # x_str = self.text[idx: idx + self.context_length]
        # y_char = self.text[idx + self.context_length]
        #
        # # Convert to tensors
        # x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
        # y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
        #
        # if self.transform is not None:
        #     x = self.transform(x)
        #
        # return x, y
        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
        y = self.tensor[self.start_byte + idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
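
Because the default size generates a very large lorem-ipsum corpus, a quick smoke test would typically pass a much smaller size; the numbers below are illustrative only:

from .LoremIpsumDataset import LoremIpsumDataset

ds = LoremIpsumDataset(split='train', size=10_000)  # 10k generated words
x, y = ds[0]
print(x.shape, y.item())  # torch.Size([128]) and a single target byte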


@@ -0,0 +1,51 @@
from typing import Callable

from datasets import load_dataset, Value, Features

from .Dataset import Dataset


class OpenGenomeDataset(Dataset):
    """
    Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome

    :param split: Either 'train', 'test' or 'validation'
    :param stage: Either 'sample', 'stage1' or 'stage2'.
        'sample' only provides a 'validation' split
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1,
                 stage: str = 'stage2'
                 ):
        super().__init__('open_genome', root, split, transform, size)

        print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
        ft = Features({'text': Value('string')})
        data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
        self.data = data['text']
        self.size = size

        # Model uses a fixed 128-byte context
        self.context_length = 128

        self.process_data()

        print("Done initializing dataset")

    def __len__(self):
        # return len(self.data) - self.context_length
        return self.chunk_offsets[-1] - self.context_length

    def __getitem__(self, idx):
        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
        x = self.tensor[idx:idx + self.context_length]
        y = self.tensor[idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
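
A usage sketch for the stage/split combinations described in the docstring; the size values below are arbitrary:

from .OpenGenomeDataset import OpenGenomeDataset

# 'sample' only ships a validation split; 'stage1' and 'stage2' also have train/test
val_sample = OpenGenomeDataset(split='validation', stage='sample', size=1_000_000)
train_full = OpenGenomeDataset(split='train', stage='stage2', size=1_000_000)

x, y = train_full[0]  # x: 128 byte ids of genome text, y: the next byte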


@@ -0,0 +1,10 @@
from .Dataset import Dataset
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset
from .OpenGenomeDataset import OpenGenomeDataset

dataset_called: dict[str, type[Dataset]] = {
    'enwik9': EnWik9DataSet,
    'lorem_ipsum': LoremIpsumDataset,
    'opengenome': OpenGenomeDataset
}
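
The mapping lets a dataset be selected by name, e.g. from a command-line flag; the dataset_name variable below is illustrative, and the relative import assumes the caller lives inside the same package:

from . import dataset_called

dataset_name = 'enwik9'  # e.g. parsed from a CLI argument or a config file
dataset_cls = dataset_called[dataset_name]
train_set = dataset_cls(split='train', size=1_000_000)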