from typing import Callable

import torch
from datasets import load_dataset
from torch import Tensor

from .Dataset import Dataset


class OpenGenomeDataset(Dataset):
    """Byte-level next-token dataset built from the OpenGenome corpus.

    Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome

    All ``text`` entries of the requested split are concatenated, encoded as
    UTF-8 and stored as one 1-D ``torch.long`` tensor of byte values. Sample
    ``i`` is the pair ``(x, y)`` where ``x`` is the window of
    ``context_length`` consecutive bytes starting at offset ``i`` and ``y`` is
    the single byte that immediately follows that window.

    :param root: Forwarded unchanged to the base ``Dataset`` constructor.
    :param split: Either 'train', 'test' or 'validation'.
    :param transform: Optional callable applied to ``x`` (never to ``y``)
        before a sample is returned.
    :param stage: Either 'sample', 'stage1' or 'stage2'. 'sample' only
        provides a 'validation' split.
    :param context_length: Number of bytes in each input window. Defaults to
        128, the fixed context length used by the model.
    :raises ValueError: If ``context_length`` is smaller than 1.
    """

    def __init__(self, root: str | None = None, split: str = 'train',
                 transform: Callable | None = None, stage: str = 'stage2',
                 context_length: int = 128):
        # Validate before the (potentially very expensive) dataset download.
        if context_length < 1:
            raise ValueError(
                f"context_length must be a positive integer, got {context_length}"
            )
        super().__init__('open_genome', root, transform)

        data = load_dataset("LongSafari/open-genome", stage)
        # Concatenate every record of the split into one contiguous byte string.
        self.__train = ''.join(data[split]['text']).encode('utf-8', errors='replace')
        # torch.long so the byte values can be used directly as class indices.
        self.data: Tensor = torch.tensor(bytearray(self.__train), dtype=torch.long)
        # Model uses fixed 128-length context by default.
        self.context_length = context_length

    def __len__(self) -> int:
        """Return the number of (window, next-byte) samples available.

        Every sample needs ``context_length`` input bytes plus one target
        byte, so a corpus no longer than the context yields zero samples.
        """
        # Clamp at zero: a negative value would make ``len()`` raise ValueError.
        return max(len(self.data) - self.context_length, 0)

    def __getitem__(self, item):
        """Return the sample starting at byte offset ``item``.

        :param item: Integer sample index; negative values count from the end.
        :returns: Tuple ``(x, y)`` with ``x`` the (optionally transformed)
            context window and ``y`` the 0-dim tensor holding the next byte.
        :raises IndexError: If ``item`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        # Support Python-style negative indexing instead of silently slicing
        # a misaligned (or empty) window.
        if item < 0:
            item += size
        if not 0 <= item < size:
            raise IndexError(
                f"index out of range for OpenGenomeDataset with {size} samples"
            )

        stop = item + self.context_length
        x = self.data[item:stop]
        y = self.data[stop]
        if self.transform is not None:
            x = self.transform(x)
        return x, y