feat: Choose dataset with options

Author: Tibo De Peuter · 2025-11-30 20:19:39 +01:00
parent 20bdd4f566
commit 81c767371e
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
5 changed files with 67 additions and 60 deletions

View file

@@ -10,11 +10,14 @@ Author: Tibo De Peuter
 class Dataset(TorchDataset, ABC):
     """Abstract base class for datasets."""

     @abstractmethod
-    def __init__(self, root: str, transform: Callable = None):
+    def __init__(self, name: str, root: str | None, transform: Callable = None):
         """
         :param root: Relative path to the dataset root directory
         """
-        self._root: str = join(curdir, 'data', root)
+        if root is None:
+            root = join(curdir, 'data')
+        self._root = join(root, name)
         self.transform = transform
         self.dataset = None
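
For reference, the new constructor resolves the storage directory as sketched below. This is a minimal standalone rendering of the path logic; the resolve_root helper is illustrative and not part of the commit:

    from os.path import curdir, join

    def resolve_root(name: str, root: str | None) -> str:
        # Mirrors the new Dataset.__init__: fall back to ./data when no root
        # is given, then append the dataset name as a subdirectory.
        if root is None:
            root = join(curdir, 'data')
        return join(root, name)

    print(resolve_root('enwik9', None))           # ./data/enwik9
    print(resolve_root('enwik9', '/mnt/scratch')) # /mnt/scratch/enwik9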

View file

@@ -1,18 +1,20 @@
-from datasets import load_dataset
-from torch.utils.data import Dataset
-import torch
-from os.path import curdir, join
 from typing import Callable
+
+import torch
+from datasets import load_dataset
+
+from .Dataset import Dataset

 class EnWik9DataSet(Dataset):
-    def __init__(self, root: str = "data", transform: Callable | None = None):
-        super().__init__()
-        self.transform = transform
+    """
+    Hugging Face: https://huggingface.co/datasets/haukur/enwik9
+    """
+    def __init__(self, root: str | None = None, transform: Callable | None = None):
+        super().__init__('enwik9', root, transform)
         # HuggingFace dataset: string text
-        path = join(curdir, root)
-        data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
+        data = load_dataset("haukur/enwik9", cache_dir=self.root, split="train")
         # Extract raw text
         text = data["text"]

@@ -40,4 +42,3 @@ class EnWik9DataSet(Dataset):
         x = self.transform(x)
         return x, y
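
Since the base class now derives the directory, instantiating the dataset is enough to place the Hugging Face cache under the resolved root. A usage sketch, assuming Dataset exposes the resolved path as a root property (the hunk above reads self.root while the base class sets self._root) and noting that the first call downloads the real enwik9 data:

    from dataset_loaders import EnWik9DataSet

    ds = EnWik9DataSet()                     # cache lands under ./data/enwik9
    ds = EnWik9DataSet(root='/mnt/scratch')  # ... or under /mnt/scratch/enwik9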

View file

@@ -1,21 +1,19 @@
 from typing import Callable

 import torch
-from os.path import curdir, join
 from lorem.text import TextLorem

 from .Dataset import Dataset

 class LoremIpsumDataset(Dataset):
-    def __init__(self, root: str = "data", transform: Callable = None):
-        super().__init__(root, transform)
+    def __init__(self, root: str | None = None, transform: Callable = None, size: int = 512):
+        super().__init__('lorem_ipsum', root, transform)

         # Generate text and convert to bytes
         _lorem = TextLorem()
-        _text = ' '.join(_lorem._word() for _ in range(512))
-        path = join(curdir, "data")
-        self._root = path
+        _text = ' '.join(_lorem._word() for _ in range(size))

         # Convert text to bytes (UTF-8 encoded)
         self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
         self.context_length = 128
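
The new size parameter makes the amount of generated text configurable instead of hard-coding 512 words. A usage sketch, assuming the lorem package is installed:

    from dataset_loaders import LoremIpsumDataset

    small = LoremIpsumDataset(size=64)  # 64 generated words instead of the default 512
    print(len(small.dataset))           # number of encoded characters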

View file

@ -1,3 +1,8 @@
from .Dataset import Dataset
from .EnWik9 import EnWik9DataSet from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset from .LoremIpsumDataset import LoremIpsumDataset
from .Dataset import Dataset
dataset_called: dict[str, type[Dataset]] = {
'enwik9': EnWik9DataSet,
'lorem_ipsum': LoremIpsumDataset
}
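
The dataset_called registry lets callers resolve a dataset class from its command-line name, so supporting a new dataset only requires a new entry here. A lookup sketch:

    from dataset_loaders import dataset_called

    DatasetCls = dataset_called['lorem_ipsum']  # -> LoremIpsumDataset
    dataset = DatasetCls(root=None)             # stored under ./data/lorem_ipsum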

View file

@@ -4,35 +4,35 @@ from math import ceil
 import torch
 from torch.utils.data import DataLoader

-from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset
+from dataset_loaders import dataset_called
 from trainers import OptunaTrainer, Trainer, FullTrainer

 BATCH_SIZE = 64

-if torch.cuda.is_available():
-    DEVICE = "cuda"
-elif torch.backends.mps.is_available():
-    DEVICE = "mps"
+if torch.accelerator.is_available():
+    DEVICE = torch.accelerator.current_accelerator().type
 else:
     DEVICE = "cpu"

 # hyper parameters
 context_length = 128

+if __name__ == "__main__":
     print(f"Running on device: {DEVICE}...")

     parser = ArgumentParser()
     parser.add_argument("--method", choices=["optuna", "train"], required=True)
     parser.add_argument("--model-path", type=str, required=False)
+    parser.add_argument_group("Data", "Data files or dataset to use")
+    parser.add_argument("--data-root", type=str, required=False)
+    parser.add_argument("dataset")
     args = parser.parse_args()

     print("Loading in the dataset...")
-    if args.method == "train":
-        dataset: Dataset = EnWik9DataSet(transform=lambda x: x.to(DEVICE))
-    elif args.method == "optuna":
-        dataset: Dataset = LoremIpsumDataset(transform=lambda x: x.to(DEVICE))
+    if args.dataset in dataset_called:
+        dataset = dataset_called[args.dataset](root=args.data_root, transform=lambda x: x.to(DEVICE))
     else:
-        raise ValueError(f"Unknown method: {args.method}")
+        # TODO Allow to import arbitrary files
+        raise NotImplementedError(f"Importing external datasets is not implemented yet")

     dataset_length = len(dataset)
     print(f"Dataset size = {dataset_length}")