diff --git a/CNN-model/dataset_loaders/Dataset.py b/CNN-model/dataset_loaders/Dataset.py
index fbac1a6..228fa59 100644
--- a/CNN-model/dataset_loaders/Dataset.py
+++ b/CNN-model/dataset_loaders/Dataset.py
@@ -10,11 +10,14 @@ Author: Tibo De Peuter
 class Dataset(TorchDataset, ABC):
     """Abstract base class for datasets."""
 
     @abstractmethod
-    def __init__(self, root: str, transform: Callable = None):
+    def __init__(self, name: str, root: str | None, transform: Callable = None):
         """
         :param root: Relative path to the dataset root directory
         """
-        self._root: str = join(curdir, 'data', root)
+        if root is None:
+            root = join(curdir, 'data')
+
+        self._root = join(root, name)
         self.transform = transform
         self.dataset = None
diff --git a/CNN-model/dataset_loaders/EnWik9.py b/CNN-model/dataset_loaders/EnWik9.py
index bef57a1..5059940 100644
--- a/CNN-model/dataset_loaders/EnWik9.py
+++ b/CNN-model/dataset_loaders/EnWik9.py
@@ -1,18 +1,20 @@
-from datasets import load_dataset
-from torch.utils.data import Dataset
-import torch
-from os.path import curdir, join
 from typing import Callable
 
+import torch
+from datasets import load_dataset
+
+from .Dataset import Dataset
+
 
 class EnWik9DataSet(Dataset):
-    def __init__(self, root: str = "data", transform: Callable | None = None):
-        super().__init__()
-        self.transform = transform
+    """
+    Hugging Face: https://huggingface.co/datasets/haukur/enwik9
+    """
+    def __init__(self, root: str | None = None, transform: Callable | None = None):
+        super().__init__('enwik9', root, transform)
 
         # HuggingFace dataset: string text
-        path = join(curdir, root)
-        data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
+        data = load_dataset("haukur/enwik9", cache_dir=self.root, split="train")
 
         # Extract raw text
         text = data["text"]
@@ -31,7 +33,7 @@ class EnWik9DataSet(Dataset):
 
     def __getitem__(self, idx):
         # context window
-        x = self.data[idx : idx + self.context_length]
+        x = self.data[idx: idx + self.context_length]
 
         # next byte target
         y = self.data[idx + self.context_length]
@@ -40,4 +42,3 @@
             x = self.transform(x)
 
         return x, y
-
diff --git a/CNN-model/dataset_loaders/LoremIpsumDataset.py b/CNN-model/dataset_loaders/LoremIpsumDataset.py
index 6ea0a85..0fad99d 100644
--- a/CNN-model/dataset_loaders/LoremIpsumDataset.py
+++ b/CNN-model/dataset_loaders/LoremIpsumDataset.py
@@ -1,21 +1,19 @@
 from typing import Callable
 
 import torch
-from os.path import curdir, join
 from lorem.text import TextLorem
+
 from .Dataset import Dataset
 
 
 class LoremIpsumDataset(Dataset):
-    def __init__(self, root: str = "data", transform: Callable = None):
-        super().__init__(root, transform)
+    def __init__(self, root: str | None = None, transform: Callable = None, size: int = 512):
+        super().__init__('lorem_ipsum', root, transform)
 
         # Generate text and convert to bytes
         _lorem = TextLorem()
-        _text = ' '.join(_lorem._word() for _ in range(512))
+        _text = ' '.join(_lorem._word() for _ in range(size))
 
-        path = join(curdir, "data")
-        self._root = path
         # Convert text to bytes (UTF-8 encoded)
         self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
         self.context_length = 128
diff --git a/CNN-model/dataset_loaders/__init__.py b/CNN-model/dataset_loaders/__init__.py
index 58336a2..63f124d 100644
--- a/CNN-model/dataset_loaders/__init__.py
+++ b/CNN-model/dataset_loaders/__init__.py
@@ -1,3 +1,8 @@
+from .Dataset import Dataset
 from .EnWik9 import EnWik9DataSet
 from .LoremIpsumDataset import LoremIpsumDataset
-from .Dataset import Dataset
\ No newline at end of file
+
+dataset_called: dict[str, type[Dataset]] = {
+    'enwik9': EnWik9DataSet,
+    'lorem_ipsum': LoremIpsumDataset
+}
diff --git a/CNN-model/main_cnn.py b/CNN-model/main_cnn.py
index 530122e..d9bf757 100644
--- a/CNN-model/main_cnn.py
+++ b/CNN-model/main_cnn.py
@@ -4,61 +4,61 @@ from math import ceil
 import torch
 from torch.utils.data import DataLoader
 
-from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset
+from dataset_loaders import dataset_called
 from trainers import OptunaTrainer, Trainer, FullTrainer
 
 BATCH_SIZE = 64
 
-if torch.cuda.is_available():
-    DEVICE = "cuda"
-elif torch.backends.mps.is_available():
-    DEVICE = "mps"
+if torch.accelerator.is_available():
+    DEVICE = torch.accelerator.current_accelerator().type
 else:
     DEVICE = "cpu"
 
 # hyper parameters
 context_length = 128
 
-if __name__ == "__main__":
-    print(f"Running on device: {DEVICE}...")
-    parser = ArgumentParser()
-    parser.add_argument("--method", choices=["optuna", "train"], required=True)
-    parser.add_argument("--model-path", type=str, required=False)
-    args = parser.parse_args()
+print(f"Running on device: {DEVICE}...")
+parser = ArgumentParser()
+parser.add_argument("--method", choices=["optuna", "train"], required=True)
+parser.add_argument("--model-path", type=str, required=False)
 
-    print("Loading in the dataset...")
-    if args.method == "train":
-        dataset: Dataset = EnWik9DataSet(transform=lambda x: x.to(DEVICE))
-    elif args.method == "optuna":
-        dataset: Dataset = LoremIpsumDataset(transform=lambda x: x.to(DEVICE))
-    else:
-        raise ValueError(f"Unknown method: {args.method}")
+parser.add_argument_group("Data", "Data files or dataset to use")
+parser.add_argument("--data-root", type=str, required=False)
+parser.add_argument("dataset")
+args = parser.parse_args()
 
-    dataset_length = len(dataset)
-    print(f"Dataset size = {dataset_length}")
+print("Loading in the dataset...")
+if args.dataset in dataset_called:
+    dataset = dataset_called[args.dataset](root=args.data_root, transform=lambda x: x.to(DEVICE))
+else:
+    # TODO Allow to import arbitrary files
+    raise NotImplementedError(f"Importing external datasets is not implemented yet")
 
-    training_size = ceil(0.8 * dataset_length)
+dataset_length = len(dataset)
+print(f"Dataset size = {dataset_length}")
 
-    print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")
+training_size = ceil(0.8 * dataset_length)
 
-    train_set, validate_set = torch.utils.data.random_split(dataset,
-                                                             [training_size, dataset_length - training_size])
-    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
-    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
-    loss_fn = torch.nn.CrossEntropyLoss()
+print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")
 
-    model = None
-    if args.model_path is not None:
-        print("Loading the model...")
-        model = torch.load(args.model_path)
+train_set, validate_set = torch.utils.data.random_split(dataset,
+                                                         [training_size, dataset_length - training_size])
+training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
+validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
+loss_fn = torch.nn.CrossEntropyLoss()
 
-    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
+model = None
+if args.model_path is not None:
+    print("Loading the model...")
+    model = torch.load(args.model_path)
 
-    trainer.execute(
-        model=model,
-        train_loader=training_loader,
-        validation_loader=validation_loader,
-        loss_fn=loss_fn,
-        n_epochs=200,
-        device=DEVICE
-    )
+trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
+
+trainer.execute(
+    model=model,
+    train_loader=training_loader,
+    validation_loader=validation_loader,
+    loss_fn=loss_fn,
+    n_epochs=200,
+    device=DEVICE
+)
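
Usage sketch (illustrative only, not part of the patch; assumes the script is run from the repository root): with these changes the dataset is picked by name through the `dataset_called` registry via the new positional `dataset` argument, so an invocation would look roughly like

    python CNN-model/main_cnn.py --method train enwik9 --data-root ./data
    python CNN-model/main_cnn.py --method optuna lorem_ipsum

When --data-root is omitted, Dataset.__init__ falls back to join(curdir, 'data'), so self._root for the first command without the flag would resolve to ./data/enwik9.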