From 849bcd7b7729d966198fd63558c277233ca3efb4 Mon Sep 17 00:00:00 2001
From: Tibo De Peuter
Date: Sun, 30 Nov 2025 21:59:13 +0100
Subject: [PATCH] feat: Human genome dataset

---
 .../dataset_loaders/OpenGenomeDataset.py      | 63 +++++++++++++++++++
 CNN-model/dataset_loaders/__init__.py         |  4 +-
 2 files changed, 66 insertions(+), 1 deletion(-)
 create mode 100644 CNN-model/dataset_loaders/OpenGenomeDataset.py

diff --git a/CNN-model/dataset_loaders/OpenGenomeDataset.py b/CNN-model/dataset_loaders/OpenGenomeDataset.py
new file mode 100644
index 0000000..585a799
--- /dev/null
+++ b/CNN-model/dataset_loaders/OpenGenomeDataset.py
@@ -0,0 +1,63 @@
+from typing import Callable
+
+import torch
+from datasets import load_dataset
+from torch import Tensor
+
+from .Dataset import Dataset
+
+
+class OpenGenomeDataset(Dataset):
+    """
+    Byte-level next-token dataset built from the OpenGenome corpus.
+
+    Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome
+
+    All ``text`` records of the chosen split are concatenated, UTF-8 encoded
+    and stored as one 1-D ``torch.long`` tensor of byte values. Sample ``i``
+    is the window ``data[i:i + context_length]`` paired with the byte that
+    directly follows it.
+
+    :param root: Forwarded unchanged to the base ``Dataset`` constructor.
+    :param split: Either 'train', 'test' or 'validation'.
+    :param transform: Optional callable applied to the context window only.
+    :param stage: Either 'sample', 'stage1' or 'stage2'.
+                  'sample' only provides a 'validation' split.
+    """
+
+    def __init__(self,
+                 root: str | None = None,
+                 split: str = 'train',
+                 transform: Callable | None = None,
+                 stage: str = 'stage2'):
+        super().__init__('open_genome', root, transform)
+
+        data = load_dataset("LongSafari/open-genome", stage)
+        # Keep the encoded corpus local: holding it on the instance would
+        # retain a second full copy of the data next to the tensor.
+        raw = ''.join(data[split]['text']).encode('utf-8', errors='replace')
+
+        self.data: Tensor = torch.tensor(bytearray(raw), dtype=torch.long)
+
+        # Model uses fixed 128-length context
+        self.context_length = 128
+
+    def __len__(self):
+        # Clamp at zero: a corpus shorter than one context window would
+        # otherwise make len() raise ValueError on a negative result.
+        return max(0, len(self.data) - self.context_length)
+
+    def __getitem__(self, item):
+        size = len(self)
+        # Accept negative indices like built-in sequences, without mutating
+        # the caller's index object.
+        index = item + size if item < 0 else item
+        if not 0 <= index < size:
+            raise IndexError(f'index {item} out of range for {size} samples')
+
+        x = self.data[index: index + self.context_length]
+        y = self.data[index + self.context_length]
+
+        if self.transform:
+            x = self.transform(x)
+
+        return x, y
diff --git a/CNN-model/dataset_loaders/__init__.py b/CNN-model/dataset_loaders/__init__.py
index 63f124d..f23312c 100644
--- a/CNN-model/dataset_loaders/__init__.py
+++ b/CNN-model/dataset_loaders/__init__.py
@@ -1,8 +1,10 @@
 from .Dataset import Dataset
 from .EnWik9 import EnWik9DataSet
 from .LoremIpsumDataset import LoremIpsumDataset
+from .OpenGenomeDataset import OpenGenomeDataset
 
 dataset_called: dict[str, type[Dataset]] = {
     'enwik9': EnWik9DataSet,
-    'lorem_ipsum': LoremIpsumDataset
+    'lorem_ipsum': LoremIpsumDataset,
+    'opengenome': OpenGenomeDataset
 }