Streamline datasets

This commit is contained in:
Tibo De Peuter 2025-12-04 23:13:16 +01:00
parent 849bcd7b77
commit befb1a96a5
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
8 changed files with 222 additions and 64 deletions

View file

@@ -1,8 +1,6 @@
from typing import Callable
import torch
from datasets import load_dataset
from torch import Tensor
from datasets import load_dataset, Value, Features
from .Dataset import Dataset
@@ -20,23 +18,32 @@ class OpenGenomeDataset(Dataset):
root: str | None = None,
split: str = 'train',
transform: Callable = None,
stage: str = 'stage2'):
super().__init__('open_genome', root, transform)
size: int = -1,
stage: str = 'stage2'
):
super().__init__('open_genome', root, split, transform, size)
data = load_dataset("LongSafari/open-genome", stage)
self.__train = ''.join(data[split]['text']).encode('utf-8', errors='replace')
self.data: Tensor = torch.tensor(bytearray(self.__train), dtype=torch.long)
print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
ft = Features({'text': Value('string')})
data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
self.data = data['text']
self.size = size
# Model uses fixed 128-length context
self.context_length = 128
def __len__(self):
    """Return how many start indices yield a complete context window.

    The final `context_length` positions cannot begin a full window,
    so they are excluded from the count.
    """
    usable = len(self.data) - self.context_length
    return usable
self.process_data()
def __getitem__(self, item):
x = self.data[item: item + self.context_length]
y = self.data[item + self.context_length]
print("Done initializing dataset")
def __len__(self):
    """Return the number of valid sample start positions.

    Positions in the last `context_length` bytes cannot begin a full
    context window and are excluded.

    NOTE(review): assumes `self.chunk_offsets[-1]` holds the total
    length of the underlying data — confirm against `process_data()`.
    """
    return self.chunk_offsets[-1] - self.context_length
def __getitem__(self, idx):
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
x = self.tensor[idx:idx + self.context_length]
y = self.tensor[idx + self.context_length]
if self.transform:
x = self.transform(x)