feat: updates to datasets/-loaders

This commit is contained in:
RobinMeersman 2025-11-27 19:26:59 +01:00
parent ed44d5b283
commit d2e6d17f55
11 changed files with 105 additions and 34 deletions

View file

@@ -0,0 +1,26 @@
from abc import abstractmethod, ABC
from os.path import join, curdir
from typing import Callable
from torch.utils.data import Dataset as TorchDataset
"""
Author: Tibo De Peuter
"""
class Dataset(TorchDataset, ABC):
    """Abstract base class for project datasets.

    Resolves the dataset directory under ``./data/<root>`` and stores an
    optional per-sample transform. Subclasses must assign the backing
    container to ``self.dataset`` in their ``__init__``.
    """

    @abstractmethod
    def __init__(self, root: str, transform: Callable = None):
        """
        :param root: Relative path to the dataset root directory
        :param transform: Optional callable applied to each sample on access
        """
        # Datasets live under ./data/<root> by convention.
        self._root: str = join(curdir, 'data', root)
        self.transform = transform
        # Placeholder; subclasses are responsible for populating this.
        self.dataset = None

    @property
    def root(self) -> str:
        """Path of the dataset directory (relative to the current dir)."""
        return self._root

    def __len__(self):
        # Delegates to the backing container; fails until a subclass
        # has assigned self.dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        # Shared default item access: apply the optional transform to the
        # raw sample. Subclasses previously duplicated this logic; they may
        # still override it.
        sample = self.dataset[idx]
        if self.transform is not None:
            return self.transform(sample)
        return sample

View file

@@ -0,0 +1,25 @@
from datasets import load_dataset
from os.path import curdir, join
from .Dataset import Dataset
from torch.utils.data import TensorDataset
from typing import Callable
class EnWik9DataSet(Dataset):
    """EnWik9 text dataset pulled from the HuggingFace hub.

    Downloads (or reads from cache) the ``haukur/enwik9`` train split and
    exposes its ``text`` column as samples.
    """

    def __init__(self, root: str = "data", transform: Callable = None):
        """
        :param root: Directory (relative to curdir) used as the HF cache dir
        :param transform: Optional callable applied to each text sample
        """
        super().__init__(root, transform)
        # Cache under ./<root> directly, overriding the base class's
        # ./data/<root> convention — preserved for backward compatibility.
        path = join(curdir, root)
        self._root = path
        data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
        # Keep the raw string column as the backing dataset. The previous
        # TensorDataset(text) call was a bug: TensorDataset requires tensor
        # arguments (it calls .size(0) on each) and raises on a list of
        # strings, so construction always failed.
        self.dataset = data["text"]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        if self.transform is not None:
            return self.transform(sample)
        return sample

View file

@@ -0,0 +1,35 @@
from typing import Callable
import torch
from os.path import curdir, join
from lorem.text import TextLorem
from .Dataset import Dataset
class LoremIpsumDataset(Dataset):
    """Synthetic dataset of lorem-ipsum text, served as fixed-length
    sequences of Unicode code points (long tensors).
    """

    def __init__(self, root: str = "data", transform: Callable = None,
                 sequence_length: int = 128):
        """
        :param root: Kept for interface compatibility with Dataset; the
            data is generated in memory and nothing is read from disk.
        :param transform: Optional callable applied to each sequence
        :param sequence_length: Number of code points per sample
            (default 128, matching the previous hard-coded value)
        """
        super().__init__(root, transform)
        # NOTE(review): overrides the base-class root (./data/<root>) with
        # ./data regardless of `root` — preserved for compatibility.
        self._root = join(curdir, "data")
        # Generate 512 words of placeholder text.
        # NOTE(review): _word() is a private TextLorem API — confirm a
        # public alternative before upgrading the `lorem` package.
        lorem = TextLorem()
        text = ' '.join(lorem._word() for _ in range(512))
        # Store Unicode code points (ord), not UTF-8 bytes: values can
        # exceed 255 for non-ASCII characters.
        codes = torch.tensor([ord(c) for c in text], dtype=torch.long)
        # Trim so the data reshapes evenly into fixed-length rows.
        usable = (codes.shape[0] // sequence_length) * sequence_length
        self.dataset = codes[:usable].view(-1, sequence_length)

    def __len__(self):
        # One sample per fixed-length sequence.
        return self.dataset.size(0)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        if self.transform is not None:
            return self.transform(sample)
        return sample

View file

@@ -0,0 +1,3 @@
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset
from .Dataset import Dataset