From d2e6d17f55697cd0e4fd87b085507e6156895b36 Mon Sep 17 00:00:00 2001
From: RobinMeersman
Date: Thu, 27 Nov 2025 19:26:59 +0100
Subject: [PATCH] feat: updates to datasets/-loaders

---
 CNN-model/dataset_loaders/Dataset.py          | 26 ++++++++++++++
 CNN-model/dataset_loaders/EnWik9.py           | 25 +++++++++++++
 .../dataset_loaders/LoremIpsumDataset.py      | 35 +++++++++++++++++++
 CNN-model/dataset_loaders/__init__.py         |  3 ++
 CNN-model/datasets/EnWik9.py                  | 11 ------
 CNN-model/datasets/LoremIpsumDataset.py       |  5 ---
 CNN-model/datasets/__init__.py                |  2 --
 CNN-model/main_cnn.py                         | 15 ++++----
 CNN-model/trainers/FullTrainer.py             |  6 ++--
 CNN-model/trainers/OptunaTrainer.py           |  6 ++--
 CNN-model/trainers/__init__.py                |  5 +--
 11 files changed, 105 insertions(+), 34 deletions(-)
 create mode 100644 CNN-model/dataset_loaders/Dataset.py
 create mode 100644 CNN-model/dataset_loaders/EnWik9.py
 create mode 100644 CNN-model/dataset_loaders/LoremIpsumDataset.py
 create mode 100644 CNN-model/dataset_loaders/__init__.py
 delete mode 100644 CNN-model/datasets/EnWik9.py
 delete mode 100644 CNN-model/datasets/LoremIpsumDataset.py
 delete mode 100644 CNN-model/datasets/__init__.py

diff --git a/CNN-model/dataset_loaders/Dataset.py b/CNN-model/dataset_loaders/Dataset.py
new file mode 100644
index 0000000..fbac1a6
--- /dev/null
+++ b/CNN-model/dataset_loaders/Dataset.py
@@ -0,0 +1,26 @@
+from abc import abstractmethod, ABC
+from os.path import join, curdir
+from typing import Callable
+
+from torch.utils.data import Dataset as TorchDataset
+
+"""
+Author: Tibo De Peuter
+"""
+class Dataset(TorchDataset, ABC):
+    """Abstract base class for datasets."""
+    @abstractmethod
+    def __init__(self, root: str, transform: Callable = None):
+        """
+        :param root: Relative path to the dataset root directory
+        """
+        self._root: str = join(curdir, 'data', root)
+        self.transform = transform
+        self.dataset = None
+
+    @property
+    def root(self):
+        return self._root
+
+    def __len__(self):
+        return len(self.dataset)
\ No newline at end of file
diff --git a/CNN-model/dataset_loaders/EnWik9.py b/CNN-model/dataset_loaders/EnWik9.py
new file mode 100644
index 0000000..32698a8
--- /dev/null
+++ b/CNN-model/dataset_loaders/EnWik9.py
@@ -0,0 +1,25 @@
+from datasets import load_dataset
+from os.path import curdir, join
+from .Dataset import Dataset
+from torch.utils.data import TensorDataset
+from typing import Callable
+
+
+class EnWik9DataSet(Dataset):
+    def __init__(self, root: str = "data", transform: Callable = None):
+        super().__init__(root, transform)
+
+        path = join(curdir, root)
+        self._root = path
+
+        data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
+        text = data["text"]
+        self.dataset = TensorDataset(text)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        if self.transform is not None:
+            return self.transform(self.dataset[idx])
+        return self.dataset[idx]
diff --git a/CNN-model/dataset_loaders/LoremIpsumDataset.py b/CNN-model/dataset_loaders/LoremIpsumDataset.py
new file mode 100644
index 0000000..342d1df
--- /dev/null
+++ b/CNN-model/dataset_loaders/LoremIpsumDataset.py
@@ -0,0 +1,35 @@
+from typing import Callable
+
+import torch
+from os.path import curdir, join
+from lorem.text import TextLorem
+from .Dataset import Dataset
+
+
+class LoremIpsumDataset(Dataset):
+    def __init__(self, root: str = "data", transform: Callable = None):
+        super().__init__(root, transform)
+
+        # Generate text and convert to bytes
+        _lorem = TextLorem()
+        _text = ' '.join(_lorem._word() for _ in range(512))
+
+        path = join(curdir, "data")
+        self._root = path
+        # Convert text to integer character codes (ASCII lorem text, so these equal the UTF-8 byte values)
+        self.dataset = torch.tensor([ord(c) for c in list(_text)], dtype=torch.long)
+
+        sequence_count = self.dataset.shape[0] // 128  # how many vectors of 128 elements can we make
+        self.dataset = self.dataset[:sequence_count * 128]
+        self.dataset = self.dataset.view(-1, 128)
+
+        print(self.dataset.shape)
+
+    def __len__(self):
+        # Number of 128-element sequences produced above
+        return self.dataset.size(0)
+
+    def __getitem__(self, idx):
+        if self.transform is not None:
+            return self.transform(self.dataset[idx])
+        return self.dataset[idx]
diff --git a/CNN-model/dataset_loaders/__init__.py b/CNN-model/dataset_loaders/__init__.py
new file mode 100644
index 0000000..58336a2
--- /dev/null
+++ b/CNN-model/dataset_loaders/__init__.py
@@ -0,0 +1,3 @@
+from .EnWik9 import EnWik9DataSet
+from .LoremIpsumDataset import LoremIpsumDataset
+from .Dataset import Dataset
\ No newline at end of file
diff --git a/CNN-model/datasets/EnWik9.py b/CNN-model/datasets/EnWik9.py
deleted file mode 100644
index 6d56f52..0000000
--- a/CNN-model/datasets/EnWik9.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from datasets import load_dataset
-from os.path import curdir, join
-
-class EnWik9DataSet:
-    def __init__(self):
-        path = join(curdir, "data")
-        self.data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
-
-
-    def __len__(self):
-        return len(self.data)
\ No newline at end of file
diff --git a/CNN-model/datasets/LoremIpsumDataset.py b/CNN-model/datasets/LoremIpsumDataset.py
deleted file mode 100644
index e9440e8..0000000
--- a/CNN-model/datasets/LoremIpsumDataset.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import lorem
-
-class LoremIpsumDataset:
-    def __init__(self):
-        self.data = lorem.text(paragraphs=100)
\ No newline at end of file
diff --git a/CNN-model/datasets/__init__.py b/CNN-model/datasets/__init__.py
deleted file mode 100644
index d525c0f..0000000
--- a/CNN-model/datasets/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from EnWik9 import EnWik9DataSet
-from LoremIpsumDataset import LoremIpsumDataset
\ No newline at end of file
diff --git a/CNN-model/main_cnn.py b/CNN-model/main_cnn.py
index a8d2b54..89bb70e 100644
--- a/CNN-model/main_cnn.py
+++ b/CNN-model/main_cnn.py
@@ -2,10 +2,10 @@ from argparse import ArgumentParser
 from math import ceil
 
 import torch
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import DataLoader
 
-from datasets import EnWik9DataSet, LoremIpsumDataset
-from trainers import OptunaTrainer, Trainer
+from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset
+from trainers import OptunaTrainer, Trainer, FullTrainer
 
 BATCH_SIZE = 64
 DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
@@ -21,9 +21,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     if args.method == "train":
-        dataset = EnWik9DataSet()
+        dataset: Dataset = EnWik9DataSet(transform=lambda x: x.to(DEVICE))
     elif args.method == "optuna":
-        dataset = LoremIpsumDataset()
+        dataset: Dataset = LoremIpsumDataset(transform=lambda x: x.to(DEVICE))
     else:
         raise ValueError(f"Unknown method: {args.method}")
 
@@ -31,9 +31,8 @@ if __name__ == "__main__":
     training_size = ceil(0.8 * dataset_length)
     print(f"training set size = {training_size}, validation set size {dataset_length - training_size}")
 
-    data = dataset.data["text"]
-    train_set, validate_set = torch.utils.data.random_split(TensorDataset(data),
+    train_set, validate_set = torch.utils.data.random_split(dataset,
                                                             [training_size,
                                                              dataset_length - training_size])
     training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
     validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
@@ -43,7 +42,7 @@ if __name__ == "__main__":
     if args.model_path is not None:
         model = torch.load(args.model_path)
 
-    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else None
+    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
 
     trainer.execute(
         model=model,
diff --git a/CNN-model/trainers/FullTrainer.py b/CNN-model/trainers/FullTrainer.py
index 6717f1d..fecfe90 100644
--- a/CNN-model/trainers/FullTrainer.py
+++ b/CNN-model/trainers/FullTrainer.py
@@ -4,9 +4,9 @@ import torch
 from torch import nn as nn
 from torch.utils.data import DataLoader
 
-from trainer import Trainer
-from train import train
-from ..utils import print_losses
+from .trainer import Trainer
+from .train import train
+from utils import print_losses
 
 class FullTrainer(Trainer):
     def execute(
diff --git a/CNN-model/trainers/OptunaTrainer.py b/CNN-model/trainers/OptunaTrainer.py
index 66850b2..e2d3d7a 100644
--- a/CNN-model/trainers/OptunaTrainer.py
+++ b/CNN-model/trainers/OptunaTrainer.py
@@ -6,9 +6,9 @@ import torch
 from torch import nn as nn
 from torch.utils.data import DataLoader
 
-from trainer import Trainer
-from ..model.cnn import CNNPredictor
-from train import train
+from .trainer import Trainer
+from model.cnn import CNNPredictor
+from .train import train
 
 
 def create_model(trial: tr.Trial, vocab_size: int = 256):
diff --git a/CNN-model/trainers/__init__.py b/CNN-model/trainers/__init__.py
index 05c38f6..1783b4d 100644
--- a/CNN-model/trainers/__init__.py
+++ b/CNN-model/trainers/__init__.py
@@ -1,2 +1,3 @@
-from OptunaTrainer import OptunaTrainer
-from trainer import Trainer
\ No newline at end of file
+from .OptunaTrainer import OptunaTrainer
+from .FullTrainer import FullTrainer
+from .trainer import Trainer
\ No newline at end of file
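
A note on the new EnWik9DataSet: load_dataset("haukur/enwik9", ...)["text"] returns a list of Python strings, while TensorDataset expects tensor arguments, so constructing and indexing this loader as written will fail at runtime (and the x.to(DEVICE) transform wired up in main_cnn.py assumes tensor items). One way to make the items batchable tensors is to chunk the raw text into fixed-length sequences of byte values, mirroring the 128-element sequences LoremIpsumDataset builds. The sketch below is a suggestion only, not part of this patch; the helper name encode_enwik9_text and the default sequence length of 128 are assumptions.

import torch


def encode_enwik9_text(lines, sequence_length: int = 128) -> torch.Tensor:
    """Flatten raw text lines into a (num_sequences, sequence_length) tensor of byte values."""
    raw = "\n".join(lines).encode("utf-8")  # values 0-255, matching the vocab_size=256 default in OptunaTrainer
    data = torch.tensor(list(raw), dtype=torch.long)
    usable = (data.shape[0] // sequence_length) * sequence_length  # drop the tail so the reshape is exact
    return data[:usable].view(-1, sequence_length)


# Inside EnWik9DataSet.__init__ this could replace `self.dataset = TensorDataset(text)`:
#     self.dataset = encode_enwik9_text(data["text"])

With items shaped like LoremIpsumDataset's, the existing 80/20 random_split and the DataLoader batching in main_cnn.py work unchanged.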
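
For anyone adding another corpus later, the pattern the new Dataset base class establishes is: call super().__init__(root, transform), fill self.dataset with fixed-length sequences, and apply the optional transform in __getitem__ (the base class already provides __len__ and the root property). Below is a minimal sketch of a hypothetical extra loader following that pattern; the class name, the corpus.txt filename and the way root is resolved are illustrative assumptions, not part of this change.

from os.path import join
from typing import Callable

import torch

from dataset_loaders import Dataset  # abstract base class introduced in this patch


class PlainTextDataset(Dataset):
    """Hypothetical loader: reads a local UTF-8 text file into 128-element training sequences."""

    def __init__(self, root: str = "data", transform: Callable = None, filename: str = "corpus.txt"):
        super().__init__(root, transform)
        with open(join(self.root, filename), "rb") as handle:
            raw = handle.read()
        data = torch.tensor(list(raw), dtype=torch.long)
        usable = (data.shape[0] // 128) * 128
        self.dataset = data[:usable].view(-1, 128)  # same (num_sequences, 128) convention as LoremIpsumDataset

    def __getitem__(self, idx):
        item = self.dataset[idx]
        return self.transform(item) if self.transform is not None else item

Wired into main_cnn.py it would be used exactly like the existing loaders, e.g. DataLoader(PlainTextDataset(transform=lambda x: x.to(DEVICE)), batch_size=BATCH_SIZE, shuffle=True).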