Streamline datasets
This commit is contained in:
parent
849bcd7b77
commit
befb1a96a5
8 changed files with 222 additions and 64 deletions
|
|
@ -1,8 +1,6 @@
|
|||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch import Tensor
|
||||
from datasets import load_dataset, Value, Features
|
||||
|
||||
from .Dataset import Dataset
|
||||
|
||||
|
|
@ -20,23 +18,32 @@ class OpenGenomeDataset(Dataset):
|
|||
root: str | None = None,
|
||||
split: str = 'train',
|
||||
transform: Callable = None,
|
||||
stage: str = 'stage2'):
|
||||
super().__init__('open_genome', root, transform)
|
||||
size: int = -1,
|
||||
stage: str = 'stage2'
|
||||
):
|
||||
super().__init__('open_genome', root, split, transform, size)
|
||||
|
||||
data = load_dataset("LongSafari/open-genome", stage)
|
||||
self.__train = ''.join(data[split]['text']).encode('utf-8', errors='replace')
|
||||
|
||||
self.data: Tensor = torch.tensor(bytearray(self.__train), dtype=torch.long)
|
||||
print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
|
||||
ft = Features({'text': Value('string')})
|
||||
data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
|
||||
self.data = data['text']
|
||||
self.size = size
|
||||
|
||||
# Model uses fixed 128-length context
|
||||
self.context_length = 128
|
||||
|
||||
def __len__(self):
    """Return the number of items this dataset can serve.

    Each item takes ``self.context_length`` consecutive elements of
    ``self.data`` as input plus the single element that follows them as
    the target, so the last ``self.context_length`` positions cannot
    start an item.

    Returns:
        ``len(self.data) - self.context_length``, clamped at zero.
    """
    # A negative result would make ``len(dataset)`` raise ValueError, so data
    # shorter than one context window is reported as an empty dataset instead.
    return max(0, len(self.data) - self.context_length)
|
||||
self.process_data()
|
||||
|
||||
def __getitem__(self, item):
|
||||
x = self.data[item: item + self.context_length]
|
||||
y = self.data[item + self.context_length]
|
||||
print("Done initializing dataset")
|
||||
|
||||
def __len__(self):
    """Return the number of items this dataset can serve.

    Each item reads ``self.context_length`` consecutive elements starting
    at its index plus the one element that follows them, so the final
    ``self.context_length`` positions cannot start an item.

    NOTE(review): ``self.chunk_offsets`` is not built in this method;
    its last entry is presumably the total number of elements across
    all chunks -- confirm against ``process_data``.

    Returns:
        ``self.chunk_offsets[-1] - self.context_length``, clamped at zero.
    """
    # A negative result would make ``len(dataset)`` raise ValueError, so data
    # shorter than one context window is reported as an empty dataset instead.
    return max(0, self.chunk_offsets[-1] - self.context_length)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
|
||||
x = self.tensor[idx:idx + self.context_length]
|
||||
y = self.tensor[idx + self.context_length]
|
||||
|
||||
if self.transform:
|
||||
x = self.transform(x)
|
||||
|
|
|
|||
Reference in a new issue