feat: Human genome dataset
This commit is contained in:
parent
b74ae7083a
commit
849bcd7b77
2 changed files with 47 additions and 1 deletions
44
CNN-model/dataset_loaders/OpenGenomeDataset.py
Normal file
44
CNN-model/dataset_loaders/OpenGenomeDataset.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from .Dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class OpenGenomeDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome
|
||||||
|
|
||||||
|
:param split Either 'train', 'test' or 'validation'
|
||||||
|
:param stage Either 'sample', 'stage1' or 'stage2'.
|
||||||
|
'sample' only provides a 'validation' split
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
root: str | None = None,
|
||||||
|
split: str = 'train',
|
||||||
|
transform: Callable = None,
|
||||||
|
stage: str = 'stage2'):
|
||||||
|
super().__init__('open_genome', root, transform)
|
||||||
|
|
||||||
|
data = load_dataset("LongSafari/open-genome", stage)
|
||||||
|
self.__train = ''.join(data[split]['text']).encode('utf-8', errors='replace')
|
||||||
|
|
||||||
|
self.data: Tensor = torch.tensor(bytearray(self.__train), dtype=torch.long)
|
||||||
|
|
||||||
|
# Model uses fixed 128-length context
|
||||||
|
self.context_length = 128
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data) - self.context_length
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
x = self.data[item: item + self.context_length]
|
||||||
|
y = self.data[item + self.context_length]
|
||||||
|
|
||||||
|
if self.transform:
|
||||||
|
x = self.transform(x)
|
||||||
|
|
||||||
|
return x, y
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
from .Dataset import Dataset
|
from .Dataset import Dataset
|
||||||
from .EnWik9 import EnWik9DataSet
|
from .EnWik9 import EnWik9DataSet
|
||||||
from .LoremIpsumDataset import LoremIpsumDataset
|
from .LoremIpsumDataset import LoremIpsumDataset
|
||||||
|
from .OpenGenomeDataset import OpenGenomeDataset
|
||||||
|
|
||||||
dataset_called: dict[str, type[Dataset]] = {
|
dataset_called: dict[str, type[Dataset]] = {
|
||||||
'enwik9': EnWik9DataSet,
|
'enwik9': EnWik9DataSet,
|
||||||
'lorem_ipsum': LoremIpsumDataset
|
'lorem_ipsum': LoremIpsumDataset,
|
||||||
|
'opengenome': OpenGenomeDataset
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Reference in a new issue