from typing import Callable

from datasets import load_dataset

from .Dataset import Dataset


class HumanReferenceGenomeDataset(Dataset):
    """
    Human reference genome sequences, streamed from Hugging Face and flattened
    into a single token tensor for next-token prediction.

    Hugging Face: https://huggingface.co/datasets/InstaDeepAI/human_reference_genome

    :param split: 'train' | 'validation' | 'test'
    :param config: '6kbp' | '12kbp' (chunk length in the HF builder config)
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = "train",
                 transform: Callable | None = None,
                 size: int = -1,
                 context_length: int = 1024,
                 config: str = "6kbp",
                 ):
        super().__init__("human_reference_genome", root, split, transform, size, context_length)

        print(f"Loading from HuggingFace (config: {config}, split: {split})")
        data = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
                            cache_dir=self.root, trust_remote_code=True)
        self.data = data["sequence"]

        self.process_data()

        print("Done initializing dataset")

    def __len__(self):
        # chunk_offsets is assumed to be built by the base class's process_data();
        # its last entry is the total token count, so this is the number of
        # positions where a full context window plus one target token still fits.
        return self.chunk_offsets[-1] - self.context_length

    def __getitem__(self, idx):
        # Sliding window over the flattened token tensor (built by process_data()):
        # the input is context_length tokens, the target is the single next token.
        x = self.tensor[idx: idx + self.context_length]
        y = self.tensor[idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
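

# Minimal usage sketch, not part of the module's API. Assumptions are flagged
# inline: the base Dataset class resolves a default cache location when root is
# None, and process_data() populates self.tensor and self.chunk_offsets as the
# methods above require. Because of the relative import, run this as a module
# (python -m <package>.<this_module>) rather than as a standalone script.
if __name__ == "__main__":
    dataset = HumanReferenceGenomeDataset(
        split="validation",      # assumption: this split exists for the chosen config
        context_length=1024,
        config="6kbp",
    )
    print(f"{len(dataset):,} context/target pairs")

    x, y = dataset[0]            # x: 1024-token context window, y: the next token
    print(type(x), type(y))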