Streamline datasets

This commit is contained in:
Tibo De Peuter 2025-12-04 23:13:16 +01:00
parent 849bcd7b77
commit befb1a96a5
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
8 changed files with 222 additions and 64 deletions

View file

@@ -2,28 +2,114 @@ from abc import abstractmethod, ABC
from os.path import join, curdir
from typing import Callable
import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm
"""
Author: Tibo De Peuter
"""
class Dataset(TorchDataset, ABC):
"""Abstract base class for datasets."""
@abstractmethod
def __init__(self, name: str, root: str | None, transform: Callable = None):
def __init__(self,
name: str,
root: str | None,
split: str = 'train',
transform: Callable = None,
size: int = -1
):
"""
:param root: Relative path to the dataset root directory
:param root: Path to the dataset root directory
:param split: The dataset split, e.g. 'train', 'validation', 'test'
:param size: Override the maximum size of the dataset, useful for debugging
"""
if root is None:
root = join(curdir, 'data')
self._root = join(root, name)
self.split = split
self.transform = transform
self.dataset = None
self.size = size
self.data = None
self.chunk_offsets: list[int] = []
self.bytes: bytes = bytes()
self.tensor: Tensor = torch.tensor([])
@property
def root(self):
return self._root
def __len__(self):
return len(self.dataset)
return len(self.dataset)
def process_data(self):
if self.size == -1:
# Just use the whole dataset
self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
else:
# Use only a partition of the data, so calculate chunk offsets
self.chunk_offsets = self.get_offsets()
self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
def get_offsets(self):
"""
Calculate for each chunk how many bytes came before it
"""
offsets = [0]
while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
idx = len(offsets) - 1
offsets.append(offsets[idx] + len(self.data[idx]))
print(offsets)
return offsets
def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
item = ''
# Determine first chunk in which item is located
chunk_idx = 0
while idx >= offsets[chunk_idx]:
chunk_idx += 1
chunk_idx -= 1
# Extract item from chunks
chunk = str(self.data[chunk_idx])
chunk_start = offsets[chunk_idx]
chunk_item_start = idx - chunk_start
item_len_remaining = context_length + 1
assert len(item) + item_len_remaining == context_length + 1
while chunk_item_start + item_len_remaining > len(chunk):
adding_now_len = len(chunk) - chunk_item_start
item += chunk[chunk_item_start:]
chunk_idx += 1
chunk = str(self.data[chunk_idx])
chunk_item_start = 0
item_len_remaining -= adding_now_len
assert len(item) + item_len_remaining == context_length + 1
item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]
assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"
# Transform to tensor
data = ''.join(item).encode('utf-8', errors='replace')
t = torch.tensor(list(data), dtype=torch.long)
x, y = t[:-1], t[-1]
if self.transform:
x = self.transform(x)
return x, y
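
The bookkeeping above is a prefix sum over chunk lengths: `get_offsets` records how much text precedes each chunk, and `get_chunked_item` walks forward from the chunk containing the requested index until the window is filled. The standalone sketch below illustrates the same indexing idea on toy data; the `chunks` strings and the `window` helper are made up for illustration and are not part of this commit.

```python
from bisect import bisect_right

# Toy chunks standing in for the dataset's text entries (hypothetical values).
chunks = ["Lorem ipsum ", "dolor sit amet, ", "consectetur adipiscing elit."]

# Prefix sums: offsets[i] = number of characters before chunk i.
offsets = [0]
for chunk in chunks:
    offsets.append(offsets[-1] + len(chunk))

def window(idx: int, length: int) -> str:
    """Return `length` characters starting at global index `idx`,
    stitching across chunk boundaries when necessary."""
    out = []
    chunk_idx = bisect_right(offsets, idx) - 1   # chunk containing idx
    local = idx - offsets[chunk_idx]             # position inside that chunk
    while length > 0:
        piece = chunks[chunk_idx][local:local + length]
        out.append(piece)
        length -= len(piece)
        chunk_idx += 1
        local = 0
    return ''.join(out)

assert window(10, 12) == "m dolor sit "
```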

View file

@@ -1,7 +1,7 @@
from math import ceil
from typing import Callable
import torch
from datasets import load_dataset
from datasets import load_dataset, Features, Value
from .Dataset import Dataset
@@ -10,33 +10,48 @@ class EnWik9DataSet(Dataset):
"""
Hugging Face: https://huggingface.co/datasets/haukur/enwik9
"""
def __init__(self, root: str | None = None, transform: Callable | None = None):
super().__init__('enwik9', root, transform)
# HuggingFace dataset: string text
data = load_dataset("haukur/enwik9", cache_dir=self.root, split="train")
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable | None = None,
size: int = -1
):
super().__init__('enwik9', root, split, transform, size)
# Extract raw text
text = data["text"]
# Convert text (Python string) → bytes → tensor of ints 0–255
# UTF-8 encoding; replace characters that cannot be encoded instead of raising errors
byte_data = "".join(text).encode("utf-8", errors="replace")
self.data = torch.tensor(list(byte_data), dtype=torch.long)
print(f"Loading from HuggingFace")
ft = Features({'text': Value('string')})
# Always request the 'train' split here; the upstream dataset only contains a training split
text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
self.data = text_chunks['text']
self.size = size
# Model uses fixed 128-length context
self.context_length = 128
self.process_data()
# Define splits manually, because they do not exist in the dataset
split_point = ceil(self.chunk_offsets[-1] * 0.8)
if self.split == 'train':
self.start_byte = 0
self.end_byte = split_point
elif self.split == 'validation':
self.start_byte = split_point
self.end_byte = self.chunk_offsets[-1]
else:
raise ValueError("split must be 'train' or 'validation'")
print("Done initializing dataset")
def __len__(self):
# number of sliding windows
return len(self.data) - self.context_length
return self.end_byte - self.start_byte - self.context_length
def __getitem__(self, idx):
# context window
x = self.data[idx: idx + self.context_length]
# next byte target
y = self.data[idx + self.context_length]
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
y = self.tensor[self.start_byte + idx + self.context_length]
if self.transform:
x = self.transform(x)
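
Note that the train/validation split above is defined over raw bytes rather than over examples, and `__len__` counts sliding windows inside each byte range. A worked example of that arithmetic with made-up numbers (a minimal sketch, not the commit's code):

```python
from math import ceil

total_bytes = 10_000                    # stands in for self.chunk_offsets[-1]
context_length = 128                    # fixed context used by the model

split_point = ceil(total_bytes * 0.8)   # first 8000 bytes go to 'train'

# One window per start position that still leaves room for a full context
# plus the next-byte target inside the split's byte range.
train_windows = split_point - context_length                  # 7872
val_windows = (total_bytes - split_point) - context_length    # 1872
print(train_windows, val_windows)
```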

View file

@@ -1,32 +1,63 @@
from math import ceil
from typing import Callable
import torch
from lorem.text import TextLorem
from tqdm import tqdm
from .Dataset import Dataset
class LoremIpsumDataset(Dataset):
def __init__(self, root: str | None = None, transform: Callable = None, size: int = 512):
super().__init__('lorem_ipsum', root, transform)
def __init__(self,
root: str | None = None,
split: str = 'train',
transform: Callable = None,
size: int = 2**30
):
super().__init__('lorem_ipsum', root, split, transform, size)
# Generate text and convert to bytes
_lorem = TextLorem()
_text = ' '.join(_lorem._word() for _ in range(size))
# Convert text to bytes (UTF-8 encoded)
self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
self.size = size
self.context_length = 128
self.process_data()
split_point = ceil(self.chunk_offsets[-1] * 0.8)
if self.split == 'train':
self.start_byte = 0
self.end_byte = split_point
elif self.split == 'validation':
self.start_byte = split_point
self.end_byte = self.chunk_offsets[-1]
else:
raise ValueError("split must be 'train' or 'validation'")
print("Done initializing dataset")
def __len__(self):
# Number of possible sequences of length context_length
return self.dataset.size(0) - self.context_length
return self.end_byte - self.start_byte - self.context_length
def __getitem__(self, idx):
x = self.dataset[idx: idx + self.context_length]
y = self.dataset[idx + self.context_length]
# Get sequence of characters
# x_str = self.text[idx: idx + self.context_length]
# y_char = self.text[idx + self.context_length]
#
# # Convert to tensors
# x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
# y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
#
# if self.transform is not None:
# x = self.transform(x)
#
# return x, y
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
y = self.tensor[self.start_byte + idx + self.context_length]
if self.transform is not None:
if self.transform:
x = self.transform(x)
return x, y

View file

@@ -1,8 +1,6 @@
from typing import Callable
import torch
from datasets import load_dataset
from torch import Tensor
from datasets import load_dataset, Value, Features
from .Dataset import Dataset
@@ -20,23 +18,32 @@ class OpenGenomeDataset(Dataset):
root: str | None = None,
split: str = 'train',
transform: Callable = None,
stage: str = 'stage2'):
super().__init__('open_genome', root, transform)
size: int = -1,
stage: str = 'stage2'
):
super().__init__('open_genome', root, split, transform, size)
data = load_dataset("LongSafari/open-genome", stage)
self.__train = ''.join(data[split]['text']).encode('utf-8', errors='replace')
self.data: Tensor = torch.tensor(bytearray(self.__train), dtype=torch.long)
print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
ft = Features({'text': Value('string')})
data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
self.data = data['text']
self.size = size
# Model uses fixed 128-length context
self.context_length = 128
def __len__(self):
return len(self.data) - self.context_length
self.process_data()
def __getitem__(self, item):
x = self.data[item: item + self.context_length]
y = self.data[item + self.context_length]
print("Done initializing dataset")
def __len__(self):
# return len(self.data) - self.context_length
return self.chunk_offsets[-1] - self.context_length
def __getitem__(self, idx):
# return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
x = self.tensor[idx:idx + self.context_length]
y = self.tensor[idx + self.context_length]
if self.transform:
x = self.transform(x)

View file

@@ -10,6 +10,8 @@ from trainers import OptunaTrainer, Trainer, FullTrainer
def parse_arguments():
parser = ArgumentParser(prog="NeuralCompression")
parser.add_argument("--debug", "-d", action="store_true", required=False,
help="Enable debug mode: smaller datasets, more information")
parser.add_argument("--verbose", "-v", action="store_true", required=False,
help="Enable verbose mode")
@@ -18,7 +20,7 @@ def parse_arguments():
dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
modelparser = ArgumentParser(add_help=False)
modelparser.add_argument("--model-path", type=str, required=True,
modelparser.add_argument("--model-path", type=str, required=False,
help="Path to the model to load/save")
fileparser = ArgumentParser(add_help=False)
@@ -33,6 +35,8 @@ def parse_arguments():
help="Only fetch the dataset, then exit")
train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
help="Method to use for training")
# TODO
compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
@ -44,7 +48,7 @@ def parse_arguments():
def main():
BATCH_SIZE = 64
BATCH_SIZE = 2
# hyper parameters
context_length = 128
@@ -57,9 +61,18 @@ def main():
DEVICE = "cpu"
print(f"Running on device: {DEVICE}...")
dataset_common_args = {
'root': args.data_root,
'transform': lambda x: x.to(DEVICE)
}
if args.debug:
dataset_common_args['size'] = 2**10
print("Loading in the dataset...")
if args.dataset in dataset_called:
dataset = dataset_called[args.dataset](root=args.data_root, transform=lambda x: x.to(DEVICE))
training_set = dataset_called[args.dataset](split='train', **dataset_common_args)
validate_set = dataset_called[args.dataset](split='validation', **dataset_common_args)
else:
# TODO Allow importing arbitrary files
raise NotImplementedError(f"Importing external datasets is not implemented yet")
@@ -68,16 +81,10 @@ def main():
# TODO Move this earlier in the chain, because now everything is converted into tensors as well?
exit(0)
dataset_length = len(dataset)
print(f"Dataset size = {dataset_length}")
training_size = ceil(0.8 * dataset_length)
print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")
train_set, validate_set = torch.utils.data.random_split(dataset, [training_size, dataset_length - training_size])
training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
training_loader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
loss_fn = torch.nn.CrossEntropyLoss()
model = None
@@ -85,8 +92,9 @@ def main():
print("Loading the model...")
model = torch.load(args.model_path)
trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer()
print("Training")
trainer.execute(
model=model,
train_loader=training_loader,

Binary file not shown.

View file

@@ -35,6 +35,11 @@ def objective_function(
class OptunaTrainer(Trainer):
def __init__(self, n_trials: int | None = None):
super().__init__()
self.n_trials = n_trials if n_trials is not None else 20
print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
def execute(
self,
model: nn.Module | None,
@@ -47,7 +52,7 @@ class OptunaTrainer(Trainer):
study = optuna.create_study(study_name="CNN network", direction="minimize")
study.optimize(
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
n_trials=20
n_trials=self.n_trials
)
best_params = study.best_trial.params

View file

@@ -1,5 +1,11 @@
# neural compression
Example usage:
```shell
python main_cnn.py --debug train --dataset enwik9 --method optuna
```
## Running locally
```