from typing import Callable

from datasets import load_dataset, Value, Features

from .Dataset import Dataset


class OpenGenomeDataset(Dataset):
    """Sliding-window next-element dataset over the OpenGenome corpus.

    Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome

    Each item is a pair ``(x, y)``: ``x`` is a window of ``context_length``
    consecutive elements of ``self.tensor`` and ``y`` is the single element
    that immediately follows that window.

    :param root: Forwarded to the base class; ``self.root`` is then used as
        the Hugging Face cache directory.
    :param split: Either 'train', 'test' or 'validation'.
    :param transform: Optional callable applied to the input window ``x``
        only (never to the target ``y``).
    :param size: Forwarded to the base class and stored as ``self.size``.
        NOTE(review): presumably -1 means "no limit"; confirm against the
        ``Dataset`` base class.
    :param stage: Either 'sample', 'stage1' or 'stage2'. 'sample' only
        provides a 'validation' split.
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1,
                 stage: str = 'stage2'
                 ):
        super().__init__('open_genome', root, split, transform, size)

        print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
        # Restrict the schema to the single string column that is used.
        ft = Features({'text': Value('string')})
        data = load_dataset(
            "LongSafari/open-genome",
            stage,
            split=split,
            cache_dir=self.root,
            features=ft,
        )
        self.data = data['text']
        self.size = size

        # Model uses fixed 128-length context
        self.context_length = 128

        # NOTE(review): process_data() is not defined in this file; it
        # presumably builds self.tensor and self.chunk_offsets from
        # self.data -- confirm against the Dataset base class.
        self.process_data()
        print("Done initializing dataset")

    def __len__(self):
        """Return the number of (window, target) pairs available.

        Every window needs ``context_length`` inputs plus one target, so the
        count is ``chunk_offsets[-1] - context_length``. The result is
        clamped at zero because ``__len__`` must never return a negative
        value (Python raises ``ValueError`` otherwise), which would happen
        for a corpus shorter than one context window.
        """
        # NOTE(review): chunk_offsets[-1] is assumed to be the total number
        # of elements in self.tensor -- confirm against process_data().
        return max(0, int(self.chunk_offsets[-1]) - self.context_length)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(x, y)`` training pair.

        :param idx: Window start position. Negative values count from the
            end, following the usual sequence convention.
        :returns: ``(x, y)`` where ``x`` is ``self.tensor[idx:idx + L]``
            (passed through ``self.transform`` when one is set) and ``y`` is
            ``self.tensor[idx + L]``, with ``L = self.context_length``.
        :raises IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        if idx < 0:
            # Rebind rather than ``+=`` so a caller-owned index object
            # (e.g. a 0-d tensor) is never mutated in place.
            idx = idx + length
        if not 0 <= idx < length:
            # Without this check a negative index would silently produce an
            # empty or misaligned window instead of failing.
            raise IndexError(
                f"index {idx} is out of range for dataset of length {length}"
            )

        end = idx + self.context_length
        x = self.tensor[idx:end]
        y = self.tensor[end]

        if self.transform:
            x = self.transform(x)

        return x, y