feat: Human genome dataset

This commit is contained in:
Tibo De Peuter 2025-11-30 21:59:13 +01:00
parent b74ae7083a
commit 849bcd7b77
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
2 changed files with 47 additions and 1 deletions

View file

@@ -0,0 +1,44 @@
from typing import Callable
import torch
from datasets import load_dataset
from torch import Tensor
from .Dataset import Dataset
class OpenGenomeDataset(Dataset):
"""
Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome
:param split Either 'train', 'test' or 'validation'
:param stage Either 'sample', 'stage1' or 'stage2'.
'sample' only provides a 'validation' split
"""
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable = None,
stage: str = 'stage2'):
super().__init__('open_genome', root, transform)
data = load_dataset("LongSafari/open-genome", stage)
self.__train = ''.join(data[split]['text']).encode('utf-8', errors='replace')
self.data: Tensor = torch.tensor(bytearray(self.__train), dtype=torch.long)
# Model uses fixed 128-length context
self.context_length = 128
def __len__(self):
return len(self.data) - self.context_length
def __getitem__(self, item):
x = self.data[item: item + self.context_length]
y = self.data[item + self.context_length]
if self.transform:
x = self.transform(x)
return x, y

View file

@@ -1,8 +1,10 @@
from .Dataset import Dataset
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset
from .OpenGenomeDataset import OpenGenomeDataset
# Registry mapping each dataset's short name to the class that implements it.
dataset_called: dict[str, type[Dataset]] = {
    'enwik9': EnWik9DataSet,
    'lorem_ipsum': LoremIpsumDataset,
    'opengenome': OpenGenomeDataset,
}