code cleanup

Robin Meersman 2025-11-30 19:21:29 +01:00
parent ea9cf12db0
commit 73d1742cbd
44 changed files with 6 additions and 2835 deletions

26
dataset_loaders/Dataset.py Normal file

@@ -0,0 +1,26 @@
"""
Author: Tibo De Peuter
"""
from abc import ABC, abstractmethod
from os.path import curdir, join
from typing import Callable

from torch.utils.data import Dataset as TorchDataset


class Dataset(TorchDataset, ABC):
    """Abstract base class for datasets."""

    @abstractmethod
    def __init__(self, root: str, transform: Callable | None = None):
        """
        :param root: Relative path to the dataset root directory
        :param transform: Optional callable applied to each sample
        """
        self._root: str = join(curdir, 'data', root)
        self.transform = transform
        self.dataset = None  # Subclasses must populate this

    @property
    def root(self):
        return self._root

    def __len__(self):
        return len(self.dataset)
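
For context, a minimal sketch of how a concrete subclass might plug into this base class. The BytesDataset name and its contents are illustrative, not part of this commit:

import torch
from dataset_loaders.Dataset import Dataset

class BytesDataset(Dataset):
    """Hypothetical subclass wrapping a fixed byte string."""
    def __init__(self, root: str = "bytes", transform=None):
        super().__init__(root, transform)
        # Populate self.dataset so the inherited __len__ works
        self.dataset = torch.tensor(list(b"hello world"), dtype=torch.long)

    def __getitem__(self, idx):
        x = self.dataset[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x

ds = BytesDataset()
print(len(ds), ds.root)  # 11 ./data/bytes (path separator depends on OS)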

43
dataset_loaders/EnWik9.py Normal file

@@ -0,0 +1,43 @@
from os.path import curdir, join
from typing import Callable

import torch
from datasets import load_dataset
from torch.utils.data import Dataset


class EnWik9DataSet(Dataset):
    def __init__(self, root: str = "data", transform: Callable | None = None):
        super().__init__()
        self.transform = transform

        # HuggingFace dataset: rows of raw text
        path = join(curdir, root)
        data = load_dataset("haukur/enwik9", cache_dir=path, split="train")

        # Extract raw text
        text = data["text"]

        # Convert text (Python string) → UTF-8 bytes → tensor of ints 0–255.
        # UTF-8 encoding always yields byte values in 0–255; errors="replace"
        # only guards against unencodable characters such as lone surrogates.
        byte_data = "".join(text).encode("utf-8", errors="replace")
        self.data = torch.tensor(list(byte_data), dtype=torch.long)

        # Model uses a fixed 128-byte context
        self.context_length = 128

    def __len__(self):
        # Number of sliding windows over the byte stream
        return len(self.data) - self.context_length

    def __getitem__(self, idx):
        # Context window
        x = self.data[idx : idx + self.context_length]
        # Next-byte target
        y = self.data[idx + self.context_length]
        if self.transform:
            x = self.transform(x)
        return x, y
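
A quick sketch of how this dataset feeds a next-byte-prediction loop via a DataLoader; the batch size is an illustrative assumption:

from torch.utils.data import DataLoader
from dataset_loaders import EnWik9DataSet

ds = EnWik9DataSet()  # downloads/caches haukur/enwik9 on first use
loader = DataLoader(ds, batch_size=32, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([32, 128]) torch.Size([32])

Each batch pairs a 128-byte context window with the single byte that follows it, which is what the fixed context_length above is for.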

34
dataset_loaders/LoremIpsumDataset.py Normal file

@@ -0,0 +1,34 @@
from os.path import curdir, join
from typing import Callable

import torch
from lorem.text import TextLorem

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    def __init__(self, root: str = "data", transform: Callable | None = None):
        super().__init__(root, transform)

        # Generate 512 random lorem-ipsum words
        _lorem = TextLorem()
        _text = ' '.join(_lorem._word() for _ in range(512))

        # Override the base-class root ('./data/<root>') with a flat './data' path
        self._root = join(curdir, "data")

        # Lorem text is plain ASCII, so UTF-8 encoding maps each character to one byte
        self.dataset = torch.tensor(list(_text.encode("utf-8")), dtype=torch.long)
        self.context_length = 128

    def __len__(self):
        # Number of sliding windows of length context_length
        return self.dataset.size(0) - self.context_length

    def __getitem__(self, idx):
        x = self.dataset[idx: idx + self.context_length]
        y = self.dataset[idx + self.context_length]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

3
dataset_loaders/__init__.py Normal file

@@ -0,0 +1,3 @@
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset
from .Dataset import Dataset
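
With these exports, the loaders import straight from the package. A minimal smoke test, assuming the fixed 128-byte context above; the printed values vary because the lorem text is random:

from dataset_loaders import LoremIpsumDataset

ds = LoremIpsumDataset()  # no download needed; generates random lorem text
x, y = ds[0]
print(len(ds), x.shape, y)  # e.g. <num windows> torch.Size([128]) tensor(<byte>)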