from os.path import curdir, join
from typing import Callable

import torch
from datasets import load_dataset
from torch.utils.data import Dataset


class EnWik9DataSet(Dataset):
    def __init__(self, root: str = "data", transform: Callable | None = None):
        super().__init__()
        self.transform = transform

        # Download/cache the HuggingFace dataset; the "train" split holds the raw text.
        path = join(curdir, root)
        data = load_dataset("haukur/enwik9", cache_dir=path, split="train")

        # The dataset column is a list of strings.
        text = data["text"]

        # Concatenate and encode to UTF-8: every resulting byte is in 0-255, so no
        # clamping is needed; errors="replace" only guards against unencodable
        # surrogates in the source text.
        byte_data = "".join(text).encode("utf-8", errors="replace")
        self.data = torch.tensor(list(byte_data), dtype=torch.long)

        # The model uses a fixed 128-byte context window.
        self.context_length = 128

    def __len__(self):
        # Number of sliding windows, each with one target byte after it.
        return len(self.data) - self.context_length

    def __getitem__(self, idx):
        # Context window of `context_length` bytes.
        x = self.data[idx : idx + self.context_length]
        # Next-byte target.
        y = self.data[idx + self.context_length]
        if self.transform:
            x = self.transform(x)
        return x, y
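
# A minimal usage sketch, not part of the dataset class: wrap the dataset in a
# DataLoader and pull one batch. The batch size, num_workers, and the "data"
# cache directory below are illustrative assumptions, not values from the project.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = EnWik9DataSet(root="data")
    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)

    # Each batch is (x, y): x has shape (batch, 128) with byte IDs in 0-255,
    # y has shape (batch,) with the next-byte targets.
    x, y = next(iter(loader))
    print(x.shape, y.shape)  # torch.Size([64, 128]) torch.Size([64])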