feat: Choose dataset with options

Tibo De Peuter 2025-11-30 20:19:39 +01:00
parent 20bdd4f566
commit 81c767371e
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
5 changed files with 67 additions and 60 deletions

View file

@@ -10,11 +10,14 @@ Author: Tibo De Peuter
 class Dataset(TorchDataset, ABC):
     """Abstract base class for datasets."""
 
     @abstractmethod
-    def __init__(self, root: str, transform: Callable = None):
+    def __init__(self, name: str, root: str | None, transform: Callable = None):
         """
         :param root: Relative path to the dataset root directory
         """
-        self._root: str = join(curdir, 'data', root)
+        if root is None:
+            root = join(curdir, 'data')
+        self._root = join(root, name)
         self.transform = transform
         self.dataset = None
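
A minimal sketch of the new base-class contract, using a hypothetical subclass (MyBytesDataset is not part of the commit; it assumes the Dataset base class stays re-exported by the dataset_loaders package shown further down, and that root=None should resolve to ./data/<name>):

    from typing import Callable

    from dataset_loaders import Dataset  # assumed package export

    class MyBytesDataset(Dataset):  # hypothetical subclass, for illustration only
        def __init__(self, root: str | None = None, transform: Callable = None):
            # the dataset name is fixed per subclass; root=None falls back to ./data/my_bytes
            super().__init__('my_bytes', root, transform)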

View file

@@ -1,18 +1,20 @@
-from datasets import load_dataset
-from torch.utils.data import Dataset
-import torch
-from os.path import curdir, join
 from typing import Callable
+
+import torch
+from datasets import load_dataset
+
+from .Dataset import Dataset
 
 
 class EnWik9DataSet(Dataset):
-    def __init__(self, root: str = "data", transform: Callable | None = None):
-        super().__init__()
-        self.transform = transform
+    """
+    Hugging Face: https://huggingface.co/datasets/haukur/enwik9
+    """
+    def __init__(self, root: str | None = None, transform: Callable | None = None):
+        super().__init__('enwik9', root, transform)
 
         # HuggingFace dataset: string text
-        path = join(curdir, root)
-        data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
+        data = load_dataset("haukur/enwik9", cache_dir=self.root, split="train")
 
         # Extract raw text
         text = data["text"]
@@ -31,7 +33,7 @@ class EnWik9DataSet(Dataset):
     def __getitem__(self, idx):
         # context window
-        x = self.data[idx : idx + self.context_length]
+        x = self.data[idx: idx + self.context_length]
 
         # next byte target
         y = self.data[idx + self.context_length]
@@ -40,4 +42,3 @@ class EnWik9DataSet(Dataset):
             x = self.transform(x)
 
         return x, y
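
For reference, the context-window indexing in __getitem__ can be reproduced in isolation; the byte values, index, and context_length below are illustrative and not taken from the dataset:

    import torch

    data = torch.tensor([72, 101, 108, 108, 111, 33], dtype=torch.long)  # bytes of "Hello!"
    context_length = 4
    idx = 1
    x = data[idx: idx + context_length]   # tensor([101, 108, 108, 111]), the context window
    y = data[idx + context_length]        # tensor(33), the next byte to predict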

View file

@@ -1,21 +1,19 @@
 from typing import Callable
 
 import torch
-from os.path import curdir, join
 from lorem.text import TextLorem
 
 from .Dataset import Dataset
 
 
 class LoremIpsumDataset(Dataset):
-    def __init__(self, root: str = "data", transform: Callable = None):
-        super().__init__(root, transform)
+    def __init__(self, root: str | None = None, transform: Callable = None, size: int = 512):
+        super().__init__('lorem_ipsum', root, transform)
 
         # Generate text and convert to bytes
         _lorem = TextLorem()
-        _text = ' '.join(_lorem._word() for _ in range(512))
-        path = join(curdir, "data")
-        self._root = path
+        _text = ' '.join(_lorem._word() for _ in range(size))
 
         # Convert text to bytes (UTF-8 encoded)
         self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
         self.context_length = 128
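
A possible use of the new size parameter (the values here are illustrative; the lorem package must be available at runtime):

    from dataset_loaders import LoremIpsumDataset

    ds = LoremIpsumDataset(size=1024)            # 1024 generated words instead of the default 512
    print(ds.dataset.shape, ds.context_length)   # encoded byte tensor and the fixed context window of 128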

View file

@@ -1,3 +1,8 @@
+from .Dataset import Dataset
 from .EnWik9 import EnWik9DataSet
 from .LoremIpsumDataset import LoremIpsumDataset
-from .Dataset import Dataset
+
+dataset_called: dict[str, type[Dataset]] = {
+    'enwik9': EnWik9DataSet,
+    'lorem_ipsum': LoremIpsumDataset
+}
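
The registry maps a command-line name to its dataset class, so selection reduces to a dictionary lookup plus instantiation (mirroring the training script below):

    from dataset_loaders import dataset_called

    cls = dataset_called['lorem_ipsum']   # -> LoremIpsumDataset
    dataset = cls(root=None)              # generated/cached under ./data/lorem_ipsum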

View file

@@ -4,61 +4,61 @@ from math import ceil
 import torch
 from torch.utils.data import DataLoader
 
-from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset
+from dataset_loaders import dataset_called
 from trainers import OptunaTrainer, Trainer, FullTrainer
 
 BATCH_SIZE = 64
-if torch.cuda.is_available():
-    DEVICE = "cuda"
-elif torch.backends.mps.is_available():
-    DEVICE = "mps"
+if torch.accelerator.is_available():
+    DEVICE = torch.accelerator.current_accelerator().type
 else:
     DEVICE = "cpu"
 
 # hyper parameters
 context_length = 128
 
-if __name__ == "__main__":
-    print(f"Running on device: {DEVICE}...")
-    parser = ArgumentParser()
-    parser.add_argument("--method", choices=["optuna", "train"], required=True)
-    parser.add_argument("--model-path", type=str, required=False)
-    args = parser.parse_args()
-
-    print("Loading in the dataset...")
-    if args.method == "train":
-        dataset: Dataset = EnWik9DataSet(transform=lambda x: x.to(DEVICE))
-    elif args.method == "optuna":
-        dataset: Dataset = LoremIpsumDataset(transform=lambda x: x.to(DEVICE))
-    else:
-        raise ValueError(f"Unknown method: {args.method}")
-
-    dataset_length = len(dataset)
-    print(f"Dataset size = {dataset_length}")
-
-    training_size = ceil(0.8 * dataset_length)
-
-    print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")
-
-    train_set, validate_set = torch.utils.data.random_split(dataset,
-                                                            [training_size, dataset_length - training_size])
-    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
-    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
-    loss_fn = torch.nn.CrossEntropyLoss()
-
-    model = None
-    if args.model_path is not None:
-        print("Loading the model...")
-        model = torch.load(args.model_path)
-
-    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
-
-    trainer.execute(
-        model=model,
-        train_loader=training_loader,
-        validation_loader=validation_loader,
-        loss_fn=loss_fn,
-        n_epochs=200,
-        device=DEVICE
-    )
+print(f"Running on device: {DEVICE}...")
+parser = ArgumentParser()
+parser.add_argument("--method", choices=["optuna", "train"], required=True)
+parser.add_argument("--model-path", type=str, required=False)
+
+parser.add_argument_group("Data", "Data files or dataset to use")
+parser.add_argument("--data-root", type=str, required=False)
+parser.add_argument("dataset")
+args = parser.parse_args()
+
+print("Loading in the dataset...")
+if args.dataset in dataset_called:
+    dataset = dataset_called[args.dataset](root=args.data_root, transform=lambda x: x.to(DEVICE))
+else:
+    # TODO Allow to import arbitrary files
+    raise NotImplementedError(f"Importing external datasets is not implemented yet")
+
+dataset_length = len(dataset)
+print(f"Dataset size = {dataset_length}")
+
+training_size = ceil(0.8 * dataset_length)
+
+print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")
+
+train_set, validate_set = torch.utils.data.random_split(dataset,
+                                                        [training_size, dataset_length - training_size])
+training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
+validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
+loss_fn = torch.nn.CrossEntropyLoss()
+
+model = None
+if args.model_path is not None:
+    print("Loading the model...")
+    model = torch.load(args.model_path)
+
+trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
+
+trainer.execute(
+    model=model,
+    train_loader=training_loader,
+    validation_loader=validation_loader,
+    loss_fn=loss_fn,
+    n_epochs=200,
+    device=DEVICE
+)
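
Example invocations of the new interface (the script name and paths are assumptions, not part of the commit); the dataset is chosen by registry name as a positional argument, and --data-root optionally overrides the default ./data directory:

    python main.py --method train enwik9
    python main.py --method optuna --data-root /scratch/data lorem_ipsum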