Streamline datasets

Tibo De Peuter 2025-12-04 23:13:16 +01:00
parent 849bcd7b77
commit befb1a96a5
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
8 changed files with 222 additions and 64 deletions

View file

@@ -2,24 +2,44 @@ from abc import abstractmethod, ABC
 from os.path import join, curdir
 from typing import Callable

+import torch
+from torch import Tensor
 from torch.utils.data import Dataset as TorchDataset
+from tqdm import tqdm

 """
 Author: Tibo De Peuter
 """


 class Dataset(TorchDataset, ABC):
     """Abstract base class for datasets."""

     @abstractmethod
-    def __init__(self, name: str, root: str | None, transform: Callable = None):
+    def __init__(self,
+                 name: str,
+                 root: str | None,
+                 split: str = 'train',
+                 transform: Callable = None,
+                 size: int = -1
+                 ):
         """
-        :param root: Relative path to the dataset root directory
+        :param root: Path to the dataset root directory
+        :param split: The dataset split, e.g. 'train', 'validation', 'test'
+        :param size: Override the maximum size of the dataset, useful for debugging
         """
         if root is None:
             root = join(curdir, 'data')
         self._root = join(root, name)

+        self.split = split
         self.transform = transform
-        self.dataset = None
+        self.size = size
+        self.data = None
+        self.chunk_offsets: list[int] = []
+        self.bytes: bytes = bytes()
+        self.tensor: Tensor = torch.tensor([])

     @property
     def root(self):

@@ -27,3 +47,69 @@ class Dataset(TorchDataset, ABC):
     def __len__(self):
         return len(self.dataset)

+    def process_data(self):
+        if self.size == -1:
+            # Just use the whole dataset
+            self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
+        else:
+            # Use only partition, calculate offsets
+            self.chunk_offsets = self.get_offsets()
+            self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
+
+        self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
+
+    def get_offsets(self):
+        """
+        Calculate for each chunk how many bytes came before it
+        """
+        offsets = [0]
+        while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
+            idx = len(offsets) - 1
+            offsets.append(offsets[idx] + len(self.data[idx]))
+        print(offsets)
+        return offsets
+
+    def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
+        item = ''
+
+        # Determine first chunk in which item is located
+        chunk_idx = 0
+        while idx >= offsets[chunk_idx]:
+            chunk_idx += 1
+        chunk_idx -= 1
+
+        # Extract item from chunks
+        chunk = str(self.data[chunk_idx])
+        chunk_start = offsets[chunk_idx]
+        chunk_item_start = idx - chunk_start
+        item_len_remaining = context_length + 1
+        assert len(item) + item_len_remaining == context_length + 1
+
+        while chunk_item_start + item_len_remaining > len(chunk):
+            adding_now_len = len(chunk) - chunk_item_start
+            item += chunk[chunk_item_start:]
+            chunk_idx += 1
+            chunk = str(self.data[chunk_idx])
+            chunk_item_start = 0
+            item_len_remaining -= adding_now_len
+            assert len(item) + item_len_remaining == context_length + 1
+
+        item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]
+        assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"
+
+        # Transform to tensor
+        data = ''.join(item).encode('utf-8', errors='replace')
+        t = torch.tensor(list(data), dtype=torch.long)
+        x, y = t[:-1], t[-1]
+
+        if self.transform:
+            x = self.transform(x)
+
+        return x, y
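Note: the chunk-offset bookkeeping above can be exercised in isolation. The following is a minimal, standalone sketch (a toy chunk list and a simplified `locate` helper, both made up for illustration) of how a flat index is mapped back to a chunk and an in-chunk position, mirroring what `get_offsets` and `get_chunked_item` do:

```python
# Toy illustration of the offset logic in Dataset.get_offsets / get_chunked_item.
chunks = ["Hello ", "world, ", "this is chunked text."]

# Cumulative counts: offsets[i] == number of characters before chunk i.
offsets = [0]
for chunk in chunks:
    offsets.append(offsets[-1] + len(chunk))
# offsets == [0, 6, 13, 34]

def locate(idx: int) -> tuple[int, int]:
    """Return (chunk_index, position_within_chunk) for a flat index."""
    chunk_idx = 0
    while chunk_idx + 1 < len(offsets) and idx >= offsets[chunk_idx + 1]:
        chunk_idx += 1
    return chunk_idx, idx - offsets[chunk_idx]

chunk_idx, pos = locate(8)                 # index 8 falls inside "world, "
assert chunks[chunk_idx][pos] == "r"
```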

View file

@@ -1,7 +1,7 @@
+from math import ceil
 from typing import Callable

-import torch
-from datasets import load_dataset
+from datasets import load_dataset, Features, Value

 from .Dataset import Dataset

@@ -10,33 +10,48 @@ class EnWik9DataSet(Dataset):
     """
     Hugging Face: https://huggingface.co/datasets/haukur/enwik9
     """

-    def __init__(self, root: str | None = None, transform: Callable | None = None):
-        super().__init__('enwik9', root, transform)
-
-        # HuggingFace dataset: string text
-        data = load_dataset("haukur/enwik9", cache_dir=self.root, split="train")
-
-        # Extract raw text
-        text = data["text"]
-
-        # Convert text (Python string) → bytes → tensor of ints 0–255
-        # UTF-8 but non-ASCII bytes may exceed 255, so enforce modulo or ignore errors
-        byte_data = "".join(text).encode("utf-8", errors="replace")
-        self.data = torch.tensor(list(byte_data), dtype=torch.long)
+    def __init__(self,
+                 root: str | None = None,
+                 split: str = 'train',
+                 transform: Callable | None = None,
+                 size: int = -1
+                 ):
+        super().__init__('enwik9', root, split, transform, size)
+
+        print(f"Loading from HuggingFace")
+        ft = Features({'text': Value('string')})
+        # Don't pass split here, dataset only contains training
+        text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
+        self.data = text_chunks['text']
+        self.size = size

         # Model uses fixed 128-length context
         self.context_length = 128

+        self.process_data()
+
+        # Define splits manually, because they do not exist in the dataset
+        split_point = ceil(self.chunk_offsets[-1] * 0.8)
+        if self.split == 'train':
+            self.start_byte = 0
+            self.end_byte = split_point
+        elif self.split == 'validation':
+            self.start_byte = split_point
+            self.end_byte = self.chunk_offsets[-1]
+        else:
+            raise ValueError("split must be 'train' or 'validation'")
+
+        print("Done initializing dataset")

     def __len__(self):
-        # number of sliding windows
-        return len(self.data) - self.context_length
+        return self.end_byte - self.start_byte - self.context_length

     def __getitem__(self, idx):
-        # context window
-        x = self.data[idx: idx + self.context_length]
-
-        # next byte target
-        y = self.data[idx + self.context_length]
+        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
+        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
+        y = self.tensor[self.start_byte + idx + self.context_length]

         if self.transform:
             x = self.transform(x)
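For context, a hedged usage sketch of the new split-aware constructor. The import path and the size cap are assumptions; the class name, constructor arguments and the 128-byte context window come from the code above:

```python
from torch.utils.data import DataLoader

# Hypothetical import path; only the class name EnWik9DataSet is from the diff.
from datasets.EnWik9Dataset import EnWik9DataSet

common = dict(root="./data", transform=None, size=2**20)   # size cap chosen for illustration
train_set = EnWik9DataSet(split='train', **common)
val_set = EnWik9DataSet(split='validation', **common)

loader = DataLoader(train_set, batch_size=64, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)   # expected: (64, 128) byte contexts and (64,) next-byte targets
```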

View file

@@ -1,32 +1,63 @@
+from math import ceil
 from typing import Callable

-import torch
 from lorem.text import TextLorem
+from tqdm import tqdm

 from .Dataset import Dataset


 class LoremIpsumDataset(Dataset):
-    def __init__(self, root: str | None = None, transform: Callable = None, size: int = 512):
-        super().__init__('lorem_ipsum', root, transform)
+    def __init__(self,
+                 root: str | None = None,
+                 split: str = 'train',
+                 transform: Callable = None,
+                 size: int = 2**30
+                 ):
+        super().__init__('lorem_ipsum', root, split, transform, size)

-        # Generate text and convert to bytes
         _lorem = TextLorem()
-        _text = ' '.join(_lorem._word() for _ in range(size))
-
-        # Convert text to bytes (UTF-8 encoded)
-        self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
+        self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
+        self.size = size

         self.context_length = 128

+        self.process_data()
+
+        split_point = ceil(self.chunk_offsets[-1] * 0.8)
+        if self.split == 'train':
+            self.start_byte = 0
+            self.end_byte = split_point
+        elif self.split == 'validation':
+            self.start_byte = split_point
+            self.end_byte = self.chunk_offsets[-1]
+        else:
+            raise ValueError("split must be 'train' or 'validation'")
+
+        print("Done initializing dataset")

     def __len__(self):
-        # Number of possible sequences of length sequence_length
-        return self.dataset.size(0) - self.context_length
+        return self.end_byte - self.start_byte - self.context_length

     def __getitem__(self, idx):
-        x = self.dataset[idx: idx + self.context_length]
-        y = self.dataset[idx + self.context_length]
+        # Get sequence of characters
+        # x_str = self.text[idx: idx + self.context_length]
+        # y_char = self.text[idx + self.context_length]
+        #
+        # # Convert to tensors
+        # x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
+        # y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
+        #
+        # if self.transform is not None:
+        #     x = self.transform(x)
+        #
+        # return x, y
+        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
+        y = self.tensor[self.start_byte + idx + self.context_length]

-        if self.transform is not None:
+        if self.transform:
             x = self.transform(x)

         return x, y
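The byte-level windowing both datasets now share is easiest to see on a tiny string. A minimal sketch, using a toy text and a context length of 4 (the real datasets use UTF-8 corpus bytes and context_length = 128):

```python
import torch

text = "lorem ipsum dolor sit amet"
tensor = torch.tensor(list(text.encode('utf-8')), dtype=torch.long)

context_length = 4
num_items = len(tensor) - context_length   # mirrors __len__ above

x = tensor[0: context_length]              # bytes for "lore"
y = tensor[context_length]                 # byte for "m", the next character
assert bytes(x.tolist()).decode() == "lore"
assert chr(int(y)) == "m"
```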

View file

@@ -1,8 +1,6 @@
 from typing import Callable

-import torch
-from datasets import load_dataset
-from torch import Tensor
+from datasets import load_dataset, Value, Features

 from .Dataset import Dataset

@@ -20,23 +18,32 @@ class OpenGenomeDataset(Dataset):
                  root: str | None = None,
                  split: str = 'train',
                  transform: Callable = None,
-                 stage: str = 'stage2'):
-        super().__init__('open_genome', root, transform)
+                 size: int = -1,
+                 stage: str = 'stage2'
+                 ):
+        super().__init__('open_genome', root, split, transform, size)

-        data = load_dataset("LongSafari/open-genome", stage)
-        self.__train = ''.join(data[split]['text']).encode('utf-8', errors='replace')
-
-        self.data: Tensor = torch.tensor(bytearray(self.__train), dtype=torch.long)
+        print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
+        ft = Features({'text': Value('string')})
+        data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
+        self.data = data['text']
+        self.size = size

         # Model uses fixed 128-length context
         self.context_length = 128

-    def __len__(self):
-        return len(self.data) - self.context_length
+        self.process_data()

-    def __getitem__(self, item):
-        x = self.data[item: item + self.context_length]
-        y = self.data[item + self.context_length]
+        print("Done initializing dataset")
+
+    def __len__(self):
+        # return len(self.data) - self.context_length
+        return self.chunk_offsets[-1] - self.context_length
+
+    def __getitem__(self, idx):
+        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
+        x = self.tensor[idx:idx + self.context_length]
+        y = self.tensor[idx + self.context_length]

         if self.transform:
             x = self.transform(x)
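As an aside, the explicit Features schema passed to load_dataset above pins the 'text' column to plain strings, so the column reads back as a list of Python strings. A minimal sketch on an in-memory dataset (toy rows chosen for illustration, not the Hugging Face download):

```python
from datasets import Dataset as HFDataset, Features, Value

ft = Features({'text': Value('string')})
toy = HFDataset.from_dict({'text': ['ACGT', 'TTGA', 'CCAT']}, features=ft)

chunks = toy['text']                       # plain list[str], like self.data above
assert all(isinstance(c, str) for c in chunks)
print(toy.features)
```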

View file

@@ -10,6 +10,8 @@ from trainers import OptunaTrainer, Trainer, FullTrainer

 def parse_arguments():
     parser = ArgumentParser(prog="NeuralCompression")
+    parser.add_argument("--debug", "-d", action="store_true", required=False,
+                        help="Enable debug mode: smaller datasets, more information")
     parser.add_argument("--verbose", "-v", action="store_true", required=False,
                         help="Enable verbose mode")

@@ -18,7 +20,7 @@ def parse_arguments():
     dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)

     modelparser = ArgumentParser(add_help=False)
-    modelparser.add_argument("--model-path", type=str, required=True,
+    modelparser.add_argument("--model-path", type=str, required=False,
                              help="Path to the model to load/save")

     fileparser = ArgumentParser(add_help=False)

@@ -33,6 +35,8 @@ def parse_arguments():
                              help="Only fetch the dataset, then exit")

     train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
+    train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
+                              help="Method to use for training")

     # TODO
     compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])

@@ -44,7 +48,7 @@

 def main():
-    BATCH_SIZE = 64
+    BATCH_SIZE = 2

     # hyper parameters
     context_length = 128

@@ -57,9 +61,18 @@ def main():
         DEVICE = "cpu"
     print(f"Running on device: {DEVICE}...")

+    dataset_common_args = {
+        'root': args.data_root,
+        'transform': lambda x: x.to(DEVICE)
+    }
+    if args.debug:
+        dataset_common_args['size'] = 2**10
+
     print("Loading in the dataset...")
     if args.dataset in dataset_called:
-        dataset = dataset_called[args.dataset](root=args.data_root, transform=lambda x: x.to(DEVICE))
+        training_set = dataset_called[args.dataset](split='train', **dataset_common_args)
+        validate_set = dataset_called[args.dataset](split='validation', **dataset_common_args)
     else:
         # TODO Allow to import arbitrary files
         raise NotImplementedError(f"Importing external datasets is not implemented yet")

@@ -68,16 +81,10 @@ def main():
         # TODO More to earlier in chain, because now everything is converted into tensors as well?
         exit(0)

-    dataset_length = len(dataset)
-    print(f"Dataset size = {dataset_length}")
-
-    training_size = ceil(0.8 * dataset_length)
-    print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")
-
-    train_set, validate_set = torch.utils.data.random_split(dataset, [training_size, dataset_length - training_size])
-    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
+    print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
+    training_loader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
     validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)

     loss_fn = torch.nn.CrossEntropyLoss()
     model = None

@@ -85,8 +92,9 @@ def main():
         print("Loading the model...")
         model = torch.load(args.model_path)

-    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
+    trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer()

+    print("Training")
     trainer.execute(
         model=model,
         train_loader=training_loader,
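To make the training data flow concrete: each batch from the loaders above is a pair of LongTensors, (BATCH_SIZE, 128) byte contexts and (BATCH_SIZE,) next-byte targets, which CrossEntropyLoss compares against 256-way logits. A minimal, self-contained sketch with a toy stand-in model (not the project's CNN):

```python
import torch
from torch import nn

BATCH_SIZE, CONTEXT_LENGTH, NUM_BYTES = 2, 128, 256

# Stand-in batch with the shapes the DataLoader yields.
x = torch.randint(0, NUM_BYTES, (BATCH_SIZE, CONTEXT_LENGTH))   # byte contexts
y = torch.randint(0, NUM_BYTES, (BATCH_SIZE,))                  # next-byte targets

# Toy model: embed bytes, flatten the context, project to 256 logits.
model = nn.Sequential(
    nn.Embedding(NUM_BYTES, 32),
    nn.Flatten(),
    nn.Linear(32 * CONTEXT_LENGTH, NUM_BYTES),
)

loss = nn.CrossEntropyLoss()(model(x), y)
print(loss.item())
```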

Binary file not shown.

View file

@@ -35,6 +35,11 @@ def objective_function(

 class OptunaTrainer(Trainer):
+    def __init__(self, n_trials: int | None = None):
+        super().__init__()
+        self.n_trials = n_trials if n_trials is not None else 20
+        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
+
     def execute(
             self,
             model: nn.Module | None,

@@ -47,7 +52,7 @@ class OptunaTrainer(Trainer):
         study = optuna.create_study(study_name="CNN network", direction="minimize")
         study.optimize(
             lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
-            n_trials=20
+            n_trials=self.n_trials
         )

         best_params = study.best_trial.params
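The configurable trial count follows the standard Optuna pattern. A minimal sketch with a dummy objective (the real objective_function, data loaders and search space live in this trainer and are not reproduced here):

```python
import optuna

def dummy_objective(trial: optuna.Trial) -> float:
    # Stand-in for objective_function(trial, train_loader, ...).
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    return (lr - 1e-3) ** 2

n_trials = 3   # e.g. the debug setting passed as OptunaTrainer(n_trials=3)
study = optuna.create_study(study_name="CNN network", direction="minimize")
study.optimize(dummy_objective, n_trials=n_trials)
print(study.best_trial.params)
```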

View file

@@ -1,5 +1,11 @@
 # neural compression

+Example usage:
+
+```shell
+python main_cnn.py --debug train --dataset enwik9 --method optuna
+```
+
 ## Running locally

 ```