59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
from math import ceil
|
|
from typing import Callable
|
|
|
|
from datasets import load_dataset, Features, Value
|
|
|
|
from .Dataset import Dataset
|
|
|
|
|
|
class EnWik9DataSet(Dataset):
    """Next-element prediction dataset over enwik9, loaded from Hugging Face.

    Source: https://huggingface.co/datasets/haukur/enwik9

    The hub dataset only ships a ``train`` split, so this class carves its
    own ``train`` / ``validation`` split out of it by position: the first
    ``train_fraction`` of the processed data (measured by
    ``self.chunk_offsets[-1]``) is ``train``; the remainder is ``validation``.

    Each item is a pair ``(x, y)``. ``x`` is a window of ``context_length``
    consecutive elements of ``self.tensor``, and ``y`` is the element
    immediately after that window. Both ``x`` and ``y`` stay inside the
    selected split.

    NOTE(review): elements are presumably bytes, judging by the
    ``start_byte`` / ``end_byte`` naming; confirm against ``process_data``.
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1,
                 *,
                 context_length: int = 128,
                 train_fraction: float = 0.8):
        """Load enwik9 and select the element range for ``split``.

        Args:
            root: Passed to the base class; ``self.root`` is then used as the
                Hugging Face ``cache_dir``.
            split: ``'train'`` or ``'validation'``.
            transform: Optional callable applied to each context window ``x``.
            size: Forwarded to the base class and also stored on ``self.size``.
            context_length: Length of each input window. Defaults to 128, the
                fixed context length the model uses.
            train_fraction: Fraction of the data assigned to ``train``; the
                rest is ``validation``. Must be within ``[0, 1]``.

        Raises:
            ValueError: If ``split`` is not ``'train'`` or ``'validation'``,
                if ``context_length`` is not positive, or if
                ``train_fraction`` is outside ``[0, 1]``.
        """
        super().__init__('enwik9', root, split, transform, size)

        # Validate arguments before the expensive download and processing.
        if self.split not in ('train', 'validation'):
            raise ValueError("split must be 'train' or 'validation'")
        if context_length < 1:
            raise ValueError("context_length must be a positive integer")
        if not 0.0 <= train_fraction <= 1.0:
            raise ValueError("train_fraction must be within [0, 1]")

        print("Loading from HuggingFace")
        features = Features({'text': Value('string')})
        # The hub dataset only provides a 'train' split. Always load that one
        # (not self.split) and derive train/validation ourselves below.
        text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root,
                                   split='train', features=features)
        self.data = text_chunks['text']
        self.size = size

        # Must be set before process_data(), matching the original ordering.
        self.context_length = context_length

        # NOTE(review): presumably builds self.tensor and self.chunk_offsets
        # (both are read below) from self.data; defined outside this file.
        self.process_data()

        # No validation split exists upstream, so split by position.
        # Assumes chunk_offsets[-1] is the total element count; TODO confirm.
        total = self.chunk_offsets[-1]
        split_point = ceil(total * train_fraction)
        if self.split == 'train':
            self.start_byte, self.end_byte = 0, split_point
        else:  # 'validation' (already validated above)
            self.start_byte, self.end_byte = split_point, total

        print("Done initializing dataset")

    def __len__(self):
        """Return the number of ``(x, y)`` pairs that fit inside this split.

        Clamped at 0 so a split shorter than ``context_length`` is reported
        as empty; ``len()`` would otherwise raise on a negative value.
        """
        return max(0, self.end_byte - self.start_byte - self.context_length)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(x, y)`` pair of this split.

        With ``start = self.start_byte + idx``, ``x`` is
        ``self.tensor[start:start + context_length]`` (passed through
        ``self.transform`` if one is set) and ``y`` is the element right
        after it. Negative indices count back from the end of the split.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising here also terminates plain ``for`` iteration at the
                split boundary.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            # Without this check, indices past the end would read from the
            # neighbouring split (e.g. train samples drawn from validation).
            raise IndexError(
                f"index out of range for split of length {length}")

        start = self.start_byte + idx
        end = start + self.context_length
        x = self.tensor[start:end]
        y = self.tensor[end]

        if self.transform is not None:
            x = self.transform(x)

        return x, y
|