Streamline datasets
This commit is contained in: parent 849bcd7b77, commit befb1a96a5
8 changed files with 222 additions and 64 deletions
@@ -2,28 +2,114 @@ from abc import abstractmethod, ABC
from os.path import join, curdir
from typing import Callable

import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm

"""
Author: Tibo De Peuter
"""


class Dataset(TorchDataset, ABC):
    """Abstract base class for datasets."""

    @abstractmethod
    def __init__(self, name: str, root: str | None, transform: Callable = None):
    def __init__(self,
                 name: str,
                 root: str | None,
                 split: str = 'train',
                 transform: Callable = None,
                 size: int = -1
                 ):
        """
        :param root: Relative path to the dataset root directory
        :param root: Path to the dataset root directory
        :param split: The dataset split, e.g. 'train', 'validation', 'test'
        :param size: Override the maximum size of the dataset, useful for debugging
        """
        if root is None:
            root = join(curdir, 'data')

        self._root = join(root, name)
        self.split = split
        self.transform = transform
        self.dataset = None
        self.size = size
        self.data = None

        self.chunk_offsets: list[int] = []
        self.bytes: bytes = bytes()
        self.tensor: Tensor = torch.tensor([])

    @property
    def root(self):
        return self._root

    def __len__(self):
        return len(self.dataset)

    def process_data(self):
        if self.size == -1:
            # Just use the whole dataset
            self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
        else:
            # Use only a partition and calculate offsets
            self.chunk_offsets = self.get_offsets()
            self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')

        self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)

    def get_offsets(self):
        """
        Calculate for each chunk how many bytes came before it.
        """
        offsets = [0]
        while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
            idx = len(offsets) - 1
            offsets.append(offsets[idx] + len(self.data[idx]))
        print(offsets)
        return offsets

    def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
        item = ''

        # Determine the first chunk in which the item is located
        chunk_idx = 0
        while idx >= offsets[chunk_idx]:
            chunk_idx += 1
        chunk_idx -= 1

        # Extract the item from the chunks
        chunk = str(self.data[chunk_idx])
        chunk_start = offsets[chunk_idx]

        chunk_item_start = idx - chunk_start
        item_len_remaining = context_length + 1

        assert len(item) + item_len_remaining == context_length + 1

        while chunk_item_start + item_len_remaining > len(chunk):
            adding_now_len = len(chunk) - chunk_item_start
            item += chunk[chunk_item_start:]

            chunk_idx += 1
            chunk = str(self.data[chunk_idx])

            chunk_item_start = 0
            item_len_remaining -= adding_now_len

            assert len(item) + item_len_remaining == context_length + 1

        item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]

        assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"

        # Transform to tensor
        data = ''.join(item).encode('utf-8', errors='replace')
        t = torch.tensor(list(data), dtype=torch.long)
        x, y = t[:-1], t[-1]

        if self.transform:
            x = self.transform(x)

        return x, y
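The offset bookkeeping above is the core of the new size-capped loading, so here is a small self-contained sketch (not from the commit; the chunk strings and cap are invented) of what `get_offsets()` computes:

```python
# Toy illustration of get_offsets(): offsets[i] is the number of bytes that
# precede chunk i, and the loop stops once the running total reaches the cap.
chunks = ["abc", "defgh", "ij", "klmnop"]  # stand-in for self.data
size = 8                                   # stand-in for self.size

offsets = [0]
while len(offsets) <= len(chunks) and offsets[-1] < size:
    idx = len(offsets) - 1
    offsets.append(offsets[idx] + len(chunks[idx]))

print(offsets)  # [0, 3, 8] -> the first two chunks already cover the 8-byte cap
```

`process_data()` then encodes only `self.data[:len(self.chunk_offsets)]` instead of the whole dataset.
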
@@ -1,7 +1,7 @@
from math import ceil
from typing import Callable

import torch
from datasets import load_dataset
from datasets import load_dataset, Features, Value

from .Dataset import Dataset

@@ -10,33 +10,48 @@ class EnWik9DataSet(Dataset):
    """
    Hugging Face: https://huggingface.co/datasets/haukur/enwik9
    """
    def __init__(self, root: str | None = None, transform: Callable | None = None):
        super().__init__('enwik9', root, transform)

        # HuggingFace dataset: string text
        data = load_dataset("haukur/enwik9", cache_dir=self.root, split="train")
    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1
                 ):
        super().__init__('enwik9', root, split, transform, size)

        # Extract raw text
        text = data["text"]

        # Convert text (Python string) → bytes → tensor of ints 0–255
        # UTF-8 but non-ASCII bytes may exceed 255, so enforce modulo or ignore errors
        byte_data = "".join(text).encode("utf-8", errors="replace")
        self.data = torch.tensor(list(byte_data), dtype=torch.long)
        print(f"Loading from HuggingFace")
        ft = Features({'text': Value('string')})
        # Don't pass split here, dataset only contains training
        text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
        self.data = text_chunks['text']
        self.size = size

        # Model uses fixed 128-length context
        self.context_length = 128

        self.process_data()

        # Define splits manually, because they do not exist in the dataset
        split_point = ceil(self.chunk_offsets[-1] * 0.8)

        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")

        print("Done initializing dataset")

    def __len__(self):
        # number of sliding windows
        return len(self.data) - self.context_length
        return self.end_byte - self.start_byte - self.context_length

    def __getitem__(self, idx):
        # context window
        x = self.data[idx: idx + self.context_length]

        # next byte target
        y = self.data[idx + self.context_length]
        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
        y = self.tensor[self.start_byte + idx + self.context_length]

        if self.transform:
            x = self.transform(x)
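The EnWik9 and Lorem Ipsum datasets now share the same next-byte, sliding-window indexing over the encoded byte tensor. A minimal, self-contained sketch of that indexing (toy data and a shortened context; not part of the commit):

```python
import torch

context_length = 8  # the real datasets use a fixed 128-byte context
data = "Wikipedia is a free online encyclopedia.".encode("utf-8", errors="replace")
tensor = torch.tensor(list(data), dtype=torch.long)

idx = 5
x = tensor[idx: idx + context_length]   # context window of byte values
y = tensor[idx + context_length]        # next-byte target
print(x.tolist(), int(y))               # [101, 100, 105, 97, 32, 105, 115, 32] 97
```

With the new split handling, the index is additionally shifted by `start_byte`, and `__len__` shrinks to `end_byte - start_byte - context_length`.
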
@@ -1,32 +1,63 @@
from math import ceil
from typing import Callable

import torch
from lorem.text import TextLorem
from tqdm import tqdm

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    def __init__(self, root: str | None = None, transform: Callable = None, size: int = 512):
        super().__init__('lorem_ipsum', root, transform)
    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable = None,
                 size: int = 2**30
                 ):
        super().__init__('lorem_ipsum', root, split, transform, size)

        # Generate text and convert to bytes
        _lorem = TextLorem()
        _text = ' '.join(_lorem._word() for _ in range(size))

        # Convert text to bytes (UTF-8 encoded)
        self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
        self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
        self.size = size

        self.context_length = 128

        self.process_data()

        split_point = ceil(self.chunk_offsets[-1] * 0.8)

        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")

        print("Done initializing dataset")

    def __len__(self):
        # Number of possible sequences of length sequence_length
        return self.dataset.size(0) - self.context_length
        return self.end_byte - self.start_byte - self.context_length

    def __getitem__(self, idx):
        x = self.dataset[idx: idx + self.context_length]
        y = self.dataset[idx + self.context_length]
        # Get sequence of characters
        # x_str = self.text[idx: idx + self.context_length]
        # y_char = self.text[idx + self.context_length]
        #
        # # Convert to tensors
        # x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
        # y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
        #
        # if self.transform is not None:
        #     x = self.transform(x)
        #
        # return x, y
        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
        y = self.tensor[self.start_byte + idx + self.context_length]

        if self.transform is not None:
        if self.transform:
            x = self.transform(x)

        return x, y
@@ -1,8 +1,6 @@
from typing import Callable

import torch
from datasets import load_dataset
from torch import Tensor
from datasets import load_dataset, Value, Features

from .Dataset import Dataset

@@ -20,23 +18,32 @@ class OpenGenomeDataset(Dataset):
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable = None,
                 stage: str = 'stage2'):
        super().__init__('open_genome', root, transform)
                 size: int = -1,
                 stage: str = 'stage2'
                 ):
        super().__init__('open_genome', root, split, transform, size)

        data = load_dataset("LongSafari/open-genome", stage)
        self.__train = ''.join(data[split]['text']).encode('utf-8', errors='replace')

        self.data: Tensor = torch.tensor(bytearray(self.__train), dtype=torch.long)
        print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
        ft = Features({'text': Value('string')})
        data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
        self.data = data['text']
        self.size = size

        # Model uses fixed 128-length context
        self.context_length = 128

    def __len__(self):
        return len(self.data) - self.context_length
        self.process_data()

    def __getitem__(self, item):
        x = self.data[item: item + self.context_length]
        y = self.data[item + self.context_length]
        print("Done initializing dataset")

    def __len__(self):
        # return len(self.data) - self.context_length
        return self.chunk_offsets[-1] - self.context_length

    def __getitem__(self, idx):
        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
        x = self.tensor[idx:idx + self.context_length]
        y = self.tensor[idx + self.context_length]

        if self.transform:
            x = self.transform(x)
@@ -10,6 +10,8 @@ from trainers import OptunaTrainer, Trainer, FullTrainer

def parse_arguments():
    parser = ArgumentParser(prog="NeuralCompression")
    parser.add_argument("--debug", "-d", action="store_true", required=False,
                        help="Enable debug mode: smaller datasets, more information")
    parser.add_argument("--verbose", "-v", action="store_true", required=False,
                        help="Enable verbose mode")

@@ -18,7 +20,7 @@ def parse_arguments():
    dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)

    modelparser = ArgumentParser(add_help=False)
    modelparser.add_argument("--model-path", type=str, required=True,
    modelparser.add_argument("--model-path", type=str, required=False,
                             help="Path to the model to load/save")

    fileparser = ArgumentParser(add_help=False)

@@ -33,6 +35,8 @@ def parse_arguments():
                              help="Only fetch the dataset, then exit")

    train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
    train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
                              help="Method to use for training")

    # TODO
    compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])

@@ -44,7 +48,7 @@ def parse_arguments():


def main():
    BATCH_SIZE = 64
    BATCH_SIZE = 2

    # hyper parameters
    context_length = 128

@@ -57,9 +61,18 @@ def main():
        DEVICE = "cpu"
    print(f"Running on device: {DEVICE}...")

    dataset_common_args = {
        'root': args.data_root,
        'transform': lambda x: x.to(DEVICE)
    }

    if args.debug:
        dataset_common_args['size'] = 2**10

    print("Loading in the dataset...")
    if args.dataset in dataset_called:
        dataset = dataset_called[args.dataset](root=args.data_root, transform=lambda x: x.to(DEVICE))
        training_set = dataset_called[args.dataset](split='train', **dataset_common_args)
        validate_set = dataset_called[args.dataset](split='validation', **dataset_common_args)
    else:
        # TODO Allow importing arbitrary files
        raise NotImplementedError(f"Importing external datasets is not implemented yet")

@@ -68,16 +81,10 @@ def main():
        # TODO Move this earlier in the chain, because now everything is converted into tensors as well?
        exit(0)

    dataset_length = len(dataset)
    print(f"Dataset size = {dataset_length}")

    training_size = ceil(0.8 * dataset_length)

    print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")

    train_set, validate_set = torch.utils.data.random_split(dataset, [training_size, dataset_length - training_size])
    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
    training_loader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)

    loss_fn = torch.nn.CrossEntropyLoss()

    model = None

@@ -85,8 +92,9 @@ def main():
        print("Loading the model...")
        model = torch.load(args.model_path)

    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
    trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer()

    print("Training")
    trainer.execute(
        model=model,
        train_loader=training_loader,
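`dataset_called` is used throughout `parse_arguments()` and `main()` but defined elsewhere in the repo. For orientation, it presumably maps dataset names to the dataset classes touched by this commit; a hypothetical sketch (only the `enwik9` key is confirmed by the README example below, the other keys and the class imports are guesses):

```python
# Hypothetical shape of the dataset registry assumed by main_cnn.py;
# imports of the three classes are omitted because their module paths
# are not shown in this diff.
dataset_called = {
    'enwik9': EnWik9DataSet,
    'lorem_ipsum': LoremIpsumDataset,
    'open_genome': OpenGenomeDataset,
}
```
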
Binary file not shown.
@@ -35,6 +35,11 @@ def objective_function(


class OptunaTrainer(Trainer):
    def __init__(self, n_trials: int | None = None):
        super().__init__()
        self.n_trials = n_trials if n_trials is not None else 20
        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

    def execute(
            self,
            model: nn.Module | None,

@@ -47,7 +52,7 @@ class OptunaTrainer(Trainer):
        study = optuna.create_study(study_name="CNN network", direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
            n_trials=20
            n_trials=self.n_trials
        )

        best_params = study.best_trial.params
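As wired up in the main script above, the new `n_trials` parameter lets a debug run keep the Optuna search short. A minimal sketch of the intended use (the import path is assumed from the `from trainers import OptunaTrainer, Trainer, FullTrainer` line earlier in this commit):

```python
from trainers import OptunaTrainer  # import path assumed from main_cnn.py

# A debug run caps the study at 3 trials; passing None falls back to the
# default of 20 set inside OptunaTrainer.__init__.
debug = True
trainer = OptunaTrainer(n_trials=3 if debug else None)
# prints: Creating Optuna trainer(n_trials = 3)
```
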
@@ -1,5 +1,11 @@
# neural compression

Example usage:

```shell
python main_cnn.py --debug train --dataset enwik9 --method optuna
```

## Running locally

```