51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
from typing import Callable
|
|
|
|
from datasets import load_dataset, Value, Features
|
|
|
|
from .Dataset import Dataset
|
|
|
|
|
|
class OpenGenomeDataset(Dataset):
    """Sliding-window next-token dataset over the LongSafari ``open-genome`` corpus.

    Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome

    Each item is a pair ``(x, y)``:

    - ``x`` is ``context_length`` consecutive entries of ``self.tensor``.
    - ``y`` is the single entry immediately following that window.

    :param root: Cache directory; ``self.root`` is passed to
        ``load_dataset`` as ``cache_dir`` (presumably set from this argument
        by the base class -- confirm).
    :param split: Either 'train', 'test' or 'validation'.
    :param transform: Optional callable applied to ``x`` in ``__getitem__``.
    :param size: Forwarded to the base class and also stored on ``self.size``.
    :param stage: Either 'sample', 'stage1' or 'stage2'.
        'sample' only provides a 'validation' split.
    :param context_length: Number of entries in each input window ``x``
        (default 128). Must be >= 1.
    :raises ValueError: If ``context_length`` is smaller than 1.
    """

    def __init__(
        self,
        root: str | None = None,
        split: str = 'train',
        transform: Callable | None = None,
        size: int = -1,
        stage: str = 'stage2',
        context_length: int = 128,
    ):
        # Validate cheap arguments before the (potentially large) download.
        if context_length < 1:
            raise ValueError(
                f"context_length must be >= 1, got {context_length}"
            )

        super().__init__('open_genome', root, split, transform, size)

        print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
        # Declare the schema explicitly: a single string column named 'text'.
        ft = Features({'text': Value('string')})
        data = load_dataset(
            "LongSafari/open-genome",
            stage,
            split=split,
            cache_dir=self.root,
            features=ft,
        )
        self.data = data['text']
        self.size = size

        # Set before process_data() in case that base-class method reads it
        # (its body is not visible here -- TODO confirm).
        self.context_length = context_length

        # NOTE(review): process_data() is defined on the base class; it is
        # presumably what populates self.tensor and self.chunk_offsets, which
        # __len__ and __getitem__ rely on -- confirm.
        self.process_data()

        print("Done initializing dataset")

    def __len__(self):
        """Return the number of ``(x, y)`` windows available.

        Every start index ``idx`` in ``[0, len)`` leaves room for a full
        window plus one target entry. The result is clamped at 0, so a corpus
        shorter than ``context_length`` yields an empty dataset rather than
        making ``len()`` raise ``ValueError`` on a negative length.
        """
        # Assumes chunk_offsets[-1] is the total number of entries in
        # self.tensor (set outside this class) -- TODO confirm.
        return max(0, self.chunk_offsets[-1] - self.context_length)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for the window starting at ``idx``.

        ``x`` is ``self.tensor[idx:idx + context_length]``, passed through
        ``self.transform`` when one is set. ``y`` is
        ``self.tensor[idx + context_length]``.
        """
        # NOTE(review): windows are sliced from the flat tensor without
        # consulting chunk_offsets, so they may straddle boundaries between
        # source records -- confirm this is intended.
        end = idx + self.context_length
        x = self.tensor[idx:end]
        y = self.tensor[end]

        # Identity check rather than truthiness: a callable that defines
        # __len__ (e.g. a container-like transform) could otherwise be falsy.
        if self.transform is not None:
            x = self.transform(x)

        return x, y
|