Streamline datasets

This commit is contained in:
Tibo De Peuter 2025-12-04 23:13:16 +01:00
parent 849bcd7b77
commit befb1a96a5
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
8 changed files with 222 additions and 64 deletions


@@ -1,7 +1,7 @@
from math import ceil
from typing import Callable
import torch
from datasets import load_dataset
from datasets import load_dataset, Features, Value
from .Dataset import Dataset
@@ -10,33 +10,48 @@ class EnWik9DataSet(Dataset):
"""
Hugging Face: https://huggingface.co/datasets/haukur/enwik9
"""
def __init__(self, root: str | None = None, transform: Callable | None = None):
super().__init__('enwik9', root, transform)
# HuggingFace dataset: string text
data = load_dataset("haukur/enwik9", cache_dir=self.root, split="train")
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable | None = None,
size: int = -1
):
super().__init__('enwik9', root, split, transform, size)
# Extract raw text
text = data["text"]
# Convert text (Python string) → bytes → tensor of ints 0-255
# UTF-8 encoding always yields byte values in 0-255; errors="replace" guards against characters that cannot be encoded
byte_data = "".join(text).encode("utf-8", errors="replace")
self.data = torch.tensor(list(byte_data), dtype=torch.long)
print(f"Loading from HuggingFace")
ft = Features({'text': Value('string')})
# Always request the 'train' split here: the HF dataset ships only a training split; validation is carved out manually below
text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
self.data = text_chunks['text']
self.size = size
# Model uses fixed 128-length context
self.context_length = 128
self.process_data()
# Define splits manually, because they do not exist in the dataset
split_point = ceil(self.chunk_offsets[-1] * 0.8)
if self.split == 'train':
self.start_byte = 0
self.end_byte = split_point
elif self.split == 'validation':
self.start_byte = split_point
self.end_byte = self.chunk_offsets[-1]
else:
raise ValueError("split must be 'train' or 'validation'")
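# Worked example (illustrative numbers, not from this commit): if chunk_offsets[-1]
# is 1_000_000 bytes, split_point = ceil(1_000_000 * 0.8) = 800_000, so 'train'
# covers bytes [0, 800_000) and 'validation' covers [800_000, 1_000_000).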
print("Done initializing dataset")
def __len__(self):
# number of sliding windows
return len(self.data) - self.context_length
return self.end_byte - self.start_byte - self.context_length
def __getitem__(self, idx):
# context window
x = self.data[idx: idx + self.context_length]
# next byte target
y = self.data[idx + self.context_length]
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
y = self.tensor[self.start_byte + idx + self.context_length]
if self.transform:
x = self.transform(x)
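For reference, a minimal usage sketch under stated assumptions: the import path, data root, and batch size are hypothetical and not part of this commit; the constructor signature and the 128-byte context window follow the diff above.

from torch.utils.data import DataLoader
# from mypackage.datasets.EnWik9DataSet import EnWik9DataSet  # hypothetical import path

train_set = EnWik9DataSet(root='./data', split='train')
val_set = EnWik9DataSet(root='./data', split='validation')

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
x, y = next(iter(train_loader))
# x: (64, 128) tensor of byte values (context window), y: (64,) next-byte targets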