from math import ceil
from typing import Callable

from datasets import load_dataset, Features, Value

from .Dataset import Dataset


class EnWik9DataSet(Dataset):
    """
    Character-level language-modeling dataset built from the enwik9 corpus
    (the first 10^9 bytes of an English Wikipedia dump).

    Hugging Face: https://huggingface.co/datasets/haukur/enwik9
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1):
        super().__init__('enwik9', root, split, transform, size)

        print("Loading from HuggingFace")
        ft = Features({'text': Value('string')})
        # Always request 'train' here: the Hugging Face dataset only ships a
        # single 'train' split, so the train/validation split is made below.
        text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root,
                                   split='train', features=ft)
        self.data = text_chunks['text']
        self.size = size

        # The model uses a fixed context length of 128.
        self.context_length = 128

        self.process_data()

        # Define the splits manually, because they do not exist in the dataset:
        # the first 80% of bytes go to training, the rest to validation.
        split_point = ceil(self.chunk_offsets[-1] * 0.8)
        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")

        print("Done initializing dataset")

    def __len__(self):
        # Every starting position that leaves room for a full context window
        # plus one target byte is a valid sample.
        return self.end_byte - self.start_byte - self.context_length

    def __getitem__(self, idx):
        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
        # Sliding-window sample: x is a context_length-sized window, y is the
        # byte immediately following it (the next-token target).
        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
        y = self.tensor[self.start_byte + idx + self.context_length]
        if self.transform:
            x = self.transform(x)
        return x, y
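

# Usage sketch (illustrative, not part of the original module). It assumes the
# base `Dataset` class's `process_data()` builds `self.tensor` (a flat sequence
# of byte/token ids) and `self.chunk_offsets`, and that samples are consumed
# through a standard torch.utils.data.DataLoader; the paths, batch size, and
# printed shapes below are assumptions, not values taken from this repository.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    # Hypothetical cache directory; replace with your own data root.
    train_set = EnWik9DataSet(root='./data', split='train')
    val_set = EnWik9DataSet(root='./data', split='validation')

    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    x, y = next(iter(train_loader))
    # x: a batch of fixed 128-length context windows; y: their next-token targets.
    print(x.shape, y.shape)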