feat: update datasets and dataset loaders

RobinMeersman 2025-11-27 19:26:59 +01:00
parent ed44d5b283
commit d2e6d17f55
11 changed files with 105 additions and 34 deletions

View file

@@ -0,0 +1,26 @@
"""
Author: Tibo De Peuter
"""
from abc import ABC, abstractmethod
from os.path import curdir, join
from typing import Callable, Optional

from torch.utils.data import Dataset as TorchDataset


class Dataset(TorchDataset, ABC):
    """Abstract base class for datasets."""

    @abstractmethod
    def __init__(self, root: str, transform: Optional[Callable] = None):
        """
        :param root: Relative path to the dataset root directory
        :param transform: Optional callable applied to every item a subclass returns
        """
        self._root: str = join(curdir, 'data', root)
        self.transform = transform
        self.dataset = None  # concrete subclasses must populate this

    @property
    def root(self):
        return self._root

    def __len__(self):
        return len(self.dataset)
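
A concrete subclass only needs to populate self.dataset and provide __getitem__; a minimal sketch (the class name and sample data below are hypothetical, not part of this commit):

    from typing import Callable, Optional
    from dataset_loaders import Dataset

    class InMemoryDataset(Dataset):
        def __init__(self, root: str = "data", transform: Optional[Callable] = None):
            super().__init__(root, transform)
            self.dataset = ["alpha", "beta", "gamma"]  # any sequence with __len__

        def __getitem__(self, idx):
            item = self.dataset[idx]
            return self.transform(item) if self.transform is not None else item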

View file

@@ -0,0 +1,25 @@
from os.path import curdir, join
from typing import Callable, Optional

from datasets import load_dataset

from .Dataset import Dataset


class EnWik9DataSet(Dataset):
    def __init__(self, root: str = "data", transform: Optional[Callable] = None):
        super().__init__(root, transform)
        self._root = join(curdir, root)  # cache the download directly under ./<root>
        data = load_dataset("haukur/enwik9", cache_dir=self._root, split="train")
        # data["text"] is a list of strings; TensorDataset only accepts tensors,
        # so keep the raw list as the backing store. __len__ comes from the base class.
        self.dataset = data["text"]

    def __getitem__(self, idx):
        if self.transform is not None:
            return self.transform(self.dataset[idx])
        return self.dataset[idx]
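
Assuming the layout above, wiring the dataset into a loader looks roughly like this (the batch size is illustrative; the first call downloads enwik9 into ./data):

    from torch.utils.data import DataLoader
    from dataset_loaders import EnWik9DataSet

    dataset = EnWik9DataSet()
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    batch = next(iter(loader))  # a list of raw text strings (default collate)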

View file

@@ -0,0 +1,35 @@
from os.path import curdir, join
from typing import Callable, Optional

import torch
from lorem.text import TextLorem

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    def __init__(self, root: str = "data", transform: Optional[Callable] = None):
        super().__init__(root, transform)
        self._root = join(curdir, root)
        # Generate placeholder text
        _lorem = TextLorem()
        _text = ' '.join(_lorem._word() for _ in range(512))
        # Convert characters to integer code points (lorem text is ASCII,
        # so these coincide with the UTF-8 byte values)
        self.dataset = torch.tensor([ord(c) for c in _text], dtype=torch.long)
        # How many full vectors of 128 elements can we make?
        sequence_count = self.dataset.shape[0] // 128
        # Drop the remainder and reshape into (sequence_count, 128)
        self.dataset = self.dataset[:sequence_count * 128]
        self.dataset = self.dataset.view(-1, 128)
        print(self.dataset.shape)

    def __len__(self):
        # Number of sequences of length 128
        return self.dataset.size(0)

    def __getitem__(self, idx):
        if self.transform is not None:
            return self.transform(self.dataset[idx])
        return self.dataset[idx]
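
To make the reshape concrete with a hypothetical 3,000-character text: sequence_count = 3000 // 128 = 23, the trailing 3000 - 23 * 128 = 56 code points are dropped, and view(-1, 128) yields a (23, 128) tensor:

    import torch

    codes = torch.randint(0, 256, (3_000,), dtype=torch.long)  # stand-in for the lorem text
    sequence_count = codes.shape[0] // 128                     # 23 full sequences
    codes = codes[:sequence_count * 128]                       # drop the 56 leftovers
    print(codes.view(-1, 128).shape)                           # torch.Size([23, 128])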

View file

@@ -0,0 +1,3 @@
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset
from .Dataset import Dataset
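
With this __init__.py, callers import from the package root rather than from individual modules, e.g.:

    from dataset_loaders import Dataset, EnWik9DataSet, LoremIpsumDataset

    assert issubclass(EnWik9DataSet, Dataset)
    assert issubclass(LoremIpsumDataset, Dataset)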

View file

@@ -1,11 +0,0 @@
from datasets import load_dataset
from os.path import curdir, join


class EnWik9DataSet:
    def __init__(self):
        path = join(curdir, "data")
        self.data = load_dataset("haukur/enwik9", cache_dir=path, split="train")

    def __len__(self):
        return len(self.data)

View file

@@ -1,5 +0,0 @@
import lorem

class LoremIpsumDataset:
    def __init__(self):
        self.data = lorem.text(paragraphs=100)

View file

@@ -1,2 +0,0 @@
from EnWik9 import EnWik9DataSet
from LoremIpsumDataset import LoremIpsumDataset

View file

@@ -2,10 +2,10 @@ from argparse import ArgumentParser
 from math import ceil

 import torch
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import DataLoader

-from datasets import EnWik9DataSet, LoremIpsumDataset
-from trainers import OptunaTrainer, Trainer
+from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset
+from trainers import OptunaTrainer, Trainer, FullTrainer

 BATCH_SIZE = 64
 DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
@@ -21,9 +21,9 @@ if __name__ == "__main__":
     args = parser.parse_args()

     if args.method == "train":
-        dataset = EnWik9DataSet()
+        dataset: Dataset = EnWik9DataSet(transform=lambda x: x.to(DEVICE))
     elif args.method == "optuna":
-        dataset = LoremIpsumDataset()
+        dataset: Dataset = LoremIpsumDataset(transform=lambda x: x.to(DEVICE))
     else:
         raise ValueError(f"Unknown method: {args.method}")
@@ -31,9 +31,8 @@ if __name__ == "__main__":
     training_size = ceil(0.8 * dataset_length)
     print(f"training set size = {training_size}, validation set size {dataset_length - training_size}")

-    data = dataset.data["text"]
-    train_set, validate_set = torch.utils.data.random_split(TensorDataset(data),
+    train_set, validate_set = torch.utils.data.random_split(dataset,
                                                             [training_size, dataset_length - training_size])

     training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
     validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
@@ -43,7 +42,7 @@ if __name__ == "__main__":
     if args.model_path is not None:
         model = torch.load(args.model_path)

-    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else None
+    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
     trainer.execute(
         model=model,
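
trainers/trainer.py itself is not shown in this diff; judging from how main.py uses it, the base class is presumably close to the following sketch (everything except the execute name and the model keyword is a guess):

    from abc import ABC, abstractmethod

    class Trainer(ABC):
        @abstractmethod
        def execute(self, model, **kwargs):
            """Run one training strategy (full training or an Optuna search)."""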

View file

@@ -4,9 +4,9 @@ import torch
 from torch import nn as nn
 from torch.utils.data import DataLoader

-from trainer import Trainer
-from train import train
-from ..utils import print_losses
+from .trainer import Trainer
+from .train import train
+from utils import print_losses


 class FullTrainer(Trainer):
     def execute(

View file

@@ -6,9 +6,9 @@ import torch
 from torch import nn as nn
 from torch.utils.data import DataLoader

-from trainer import Trainer
-from ..model.cnn import CNNPredictor
-from train import train
+from .trainer import Trainer
+from model.cnn import CNNPredictor
+from .train import train


 def create_model(trial: tr.Trial, vocab_size: int = 256):
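
create_model samples a model from an Optuna trial (vocab_size=256 matches the byte-level inputs above); it would typically be called from an objective function, roughly like this hypothetical sketch (the study setup is not part of this diff):

    import optuna

    def objective(trial: optuna.Trial) -> float:
        model = create_model(trial)  # hyperparameters sampled from the trial
        ...                          # train briefly, then evaluate
        return 0.0                   # placeholder for the validation loss

    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=20)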

View file

@@ -1,2 +1,3 @@
-from OptunaTrainer import OptunaTrainer
-from trainer import Trainer
+from .OptunaTrainer import OptunaTrainer
+from .FullTrainer import FullTrainer
+from .trainer import Trainer
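
Most of the import churn in this commit follows one rule: a module inside a package must use explicit relative imports for its siblings. A minimal reproduction, assuming the trainers/ layout implied by the diff:

    # inside trainers/FullTrainer.py
    # from trainer import Trainer   # Python 3: "No module named 'trainer'" once trainers is a package
    from .trainer import Trainer    # resolved relative to the trainers package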