diff --git a/CNN-model/dataset_loaders/Dataset.py b/CNN-model/dataset_loaders/Dataset.py
index 228fa59..63763af 100644
--- a/CNN-model/dataset_loaders/Dataset.py
+++ b/CNN-model/dataset_loaders/Dataset.py
@@ -2,28 +2,114 @@
 from abc import abstractmethod, ABC
 from os.path import join, curdir
 from typing import Callable
 
+import torch
+from torch import Tensor
 from torch.utils.data import Dataset as TorchDataset
+from tqdm import tqdm
 
 """
 Author: Tibo De Peuter
 """
+
+
 class Dataset(TorchDataset, ABC):
     """Abstract base class for datasets."""
+
     @abstractmethod
-    def __init__(self, name: str, root: str | None, transform: Callable = None):
+    def __init__(self,
+                 name: str,
+                 root: str | None,
+                 split: str = 'train',
+                 transform: Callable = None,
+                 size: int = -1
+                 ):
         """
-        :param root: Relative path to the dataset root directory
+        :param root: Path to the dataset root directory
+        :param split: The dataset split, e.g. 'train', 'validation', 'test'
+        :param size: Override the maximum size of the dataset, useful for debugging
         """
         if root is None:
             root = join(curdir, 'data')
         self._root = join(root, name)
+        self.split = split
         self.transform = transform
-        self.dataset = None
+        self.size = size
+        self.data = None
+
+        self.chunk_offsets: list[int] = []
+        self.bytes: bytes = bytes()
+        self.tensor: Tensor = torch.tensor([])
 
     @property
     def root(self):
         return self._root
 
     def __len__(self):
-        return len(self.dataset)
\ No newline at end of file
+        return len(self.dataset)
+
+    def process_data(self):
+        if self.size == -1:
+            # Just use the whole dataset
+            self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
+        else:
+            # Use only partition, calculate offsets
+            self.chunk_offsets = self.get_offsets()
+            self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')
+
+        self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
+
+    def get_offsets(self):
+        """
+        Calculate for each chunk how many bytes came before it
+        """
+        offsets = [0]
+        while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
+            idx = len(offsets) - 1
+            offsets.append(offsets[idx] + len(self.data[idx]))
+        print(offsets)
+        return offsets
+
+    def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
+        item = ''
+
+        # Determine first chunk in which item is located
+        chunk_idx = 0
+        while idx >= offsets[chunk_idx]:
+            chunk_idx += 1
+        chunk_idx -= 1
+
+        # Extract item from chunks
+        chunk = str(self.data[chunk_idx])
+        chunk_start = offsets[chunk_idx]
+
+        chunk_item_start = idx - chunk_start
+        item_len_remaining = context_length + 1
+
+        assert len(item) + item_len_remaining == context_length + 1
+
+        while chunk_item_start + item_len_remaining > len(chunk):
+            adding_now_len = len(chunk) - chunk_item_start
+            item += chunk[chunk_item_start:]
+
+            chunk_idx += 1
+            chunk = str(self.data[chunk_idx])
+
+            chunk_item_start = 0
+            item_len_remaining -= adding_now_len
+
+            assert len(item) + item_len_remaining == context_length + 1
+
+        item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]
+
+        assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"
+
+        # Transform to tensor
+        data = ''.join(item).encode('utf-8', errors='replace')
+        t = torch.tensor(list(data), dtype=torch.long)
+        x, y = t[:-1], t[-1]
+
+        if self.transform:
+            x = self.transform(x)
+
+        return x, y
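Note on the new base-class helpers: `get_offsets` is a prefix sum over chunk lengths, and `get_chunked_item` walks that table to read an item that may span several chunks. A minimal self-contained sketch of the same idea, with made-up chunks (not code from this patch):

```python
chunks = ["Hello, ", "world! ", "This is a test."]

# Prefix sums: offsets[i] == number of characters before chunk i,
# so offsets[-1] is the total length (cf. get_offsets()).
offsets = [0]
for chunk in chunks:
    offsets.append(offsets[-1] + len(chunk))

def read_span(start: int, length: int) -> str:
    """Read `length` characters at global index `start`, crossing chunk
    boundaries where needed (cf. get_chunked_item())."""
    # Find the chunk containing `start` (linear scan, as in the patch)
    i = 0
    while i + 1 < len(offsets) and offsets[i + 1] <= start:
        i += 1
    pos = start - offsets[i]
    out = []
    while length > 0:
        take = min(length, len(chunks[i]) - pos)
        out.append(chunks[i][pos:pos + take])
        length -= take
        i, pos = i + 1, 0
    return ''.join(out)

assert read_span(5, 10) == "".join(chunks)[5:15]
```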
diff --git a/CNN-model/dataset_loaders/EnWik9.py b/CNN-model/dataset_loaders/EnWik9.py
index 5059940..0af0be3 100644
--- a/CNN-model/dataset_loaders/EnWik9.py
+++ b/CNN-model/dataset_loaders/EnWik9.py
@@ -1,7 +1,7 @@
+from math import ceil
 from typing import Callable
 
-import torch
-from datasets import load_dataset
+from datasets import load_dataset, Features, Value
 
 from .Dataset import Dataset
 
@@ -10,33 +10,48 @@ class EnWik9DataSet(Dataset):
     """
     Hugging Face: https://huggingface.co/datasets/haukur/enwik9
     """
-    def __init__(self, root: str | None = None, transform: Callable | None = None):
-        super().__init__('enwik9', root, transform)
-
-        # HuggingFace dataset: string text
-        data = load_dataset("haukur/enwik9", cache_dir=self.root, split="train")
+    def __init__(self,
+                 root: str | None = None,
+                 split: str = 'train',
+                 transform: Callable | None = None,
+                 size: int = -1
+                 ):
+        super().__init__('enwik9', root, split, transform, size)
 
-        # Extract raw text
-        text = data["text"]
-
-        # Convert text (Python string) → bytes → tensor of ints 0–255
-        # UTF-8 but non-ASCII bytes may exceed 255, so enforce modulo or ignore errors
-        byte_data = "".join(text).encode("utf-8", errors="replace")
-        self.data = torch.tensor(list(byte_data), dtype=torch.long)
+        print(f"Loading from HuggingFace")
+        ft = Features({'text': Value('string')})
+        # Don't pass split here, dataset only contains training
+        text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
+        self.data = text_chunks['text']
+        self.size = size
 
         # Model uses fixed 128-length context
         self.context_length = 128
 
+        self.process_data()
+
+        # Define splits manually, because they do not exist in the dataset
+        split_point = ceil(self.chunk_offsets[-1] * 0.8)
+
+        if self.split == 'train':
+            self.start_byte = 0
+            self.end_byte = split_point
+        elif self.split == 'validation':
+            self.start_byte = split_point
+            self.end_byte = self.chunk_offsets[-1]
+        else:
+            raise ValueError("split must be 'train' or 'validation'")
+
+        print("Done initializing dataset")
+
     def __len__(self):
-        # number of sliding windows
-        return len(self.data) - self.context_length
+        return self.end_byte - self.start_byte - self.context_length
 
     def __getitem__(self, idx):
-        # context window
-        x = self.data[idx: idx + self.context_length]
-
-        # next byte target
-        y = self.data[idx + self.context_length]
+        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
+        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
+        y = self.tensor[self.start_byte + idx + self.context_length]
 
         if self.transform:
             x = self.transform(x)
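Hypothetical usage of the reworked EnWik9 loader (import path assumed from the repository layout; a `size` cap is passed because `process_data` only fills `chunk_offsets`, which the split arithmetic reads, when a cap is given):

```python
from dataset_loaders.EnWik9 import EnWik9DataSet

# Manual 80/20 byte-level split, as wired up in __init__
train = EnWik9DataSet(split='train', size=2**10)
val = EnWik9DataSet(split='validation', size=2**10)

x, y = train[0]
assert x.shape == (128,)  # context window of byte values
assert y.dim() == 0       # single next-byte target
```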
diff --git a/CNN-model/dataset_loaders/LoremIpsumDataset.py b/CNN-model/dataset_loaders/LoremIpsumDataset.py
index 0fad99d..5dece41 100644
--- a/CNN-model/dataset_loaders/LoremIpsumDataset.py
+++ b/CNN-model/dataset_loaders/LoremIpsumDataset.py
@@ -1,32 +1,63 @@
+from math import ceil
 from typing import Callable
 
-import torch
 from lorem.text import TextLorem
+from tqdm import tqdm
 
 from .Dataset import Dataset
 
 
 class LoremIpsumDataset(Dataset):
-    def __init__(self, root: str | None = None, transform: Callable = None, size: int = 512):
-        super().__init__('lorem_ipsum', root, transform)
+    def __init__(self,
+                 root: str | None = None,
+                 split: str = 'train',
+                 transform: Callable = None,
+                 size: int = 2**30
+                 ):
+        super().__init__('lorem_ipsum', root, split, transform, size)
 
-        # Generate text and convert to bytes
         _lorem = TextLorem()
-        _text = ' '.join(_lorem._word() for _ in range(size))
-
-        # Convert text to bytes (UTF-8 encoded)
-        self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
+        self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
+        self.size = size
+        self.context_length = 128
+        self.process_data()
+
+        split_point = ceil(self.chunk_offsets[-1] * 0.8)
+
+        if self.split == 'train':
+            self.start_byte = 0
+            self.end_byte = split_point
+        elif self.split == 'validation':
+            self.start_byte = split_point
+            self.end_byte = self.chunk_offsets[-1]
+        else:
+            raise ValueError("split must be 'train' or 'validation'")
+
+        print("Done initializing dataset")
+
     def __len__(self):
-        # Number of possible sequences of length sequence_length
-        return self.dataset.size(0) - self.context_length
+        return self.end_byte - self.start_byte - self.context_length
 
     def __getitem__(self, idx):
-        x = self.dataset[idx: idx + self.context_length]
-        y = self.dataset[idx + self.context_length]
+        # Get sequence of characters
+        # x_str = self.text[idx: idx + self.context_length]
+        # y_char = self.text[idx + self.context_length]
+        #
+        # # Convert to tensors
+        # x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
+        # y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
+        #
+        # if self.transform is not None:
+        #     x = self.transform(x)
+        #
+        # return x, y
+        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
+        y = self.tensor[self.start_byte + idx + self.context_length]
 
-        if self.transform is not None:
+        if self.transform:
             x = self.transform(x)
 
         return x, y
diff --git a/CNN-model/dataset_loaders/OpenGenomeDataset.py b/CNN-model/dataset_loaders/OpenGenomeDataset.py
index 585a799..05ee2b5 100644
--- a/CNN-model/dataset_loaders/OpenGenomeDataset.py
+++ b/CNN-model/dataset_loaders/OpenGenomeDataset.py
@@ -1,8 +1,6 @@
 from typing import Callable
 
-import torch
-from datasets import load_dataset
-from torch import Tensor
+from datasets import load_dataset, Value, Features
 
 from .Dataset import Dataset
 
@@ -20,23 +18,32 @@ class OpenGenomeDataset(Dataset):
                  root: str | None = None,
                  split: str = 'train',
                  transform: Callable = None,
-                 stage: str = 'stage2'):
-        super().__init__('open_genome', root, transform)
+                 size: int = -1,
+                 stage: str = 'stage2'
+                 ):
+        super().__init__('open_genome', root, split, transform, size)
 
-        data = load_dataset("LongSafari/open-genome", stage)
-        self.__train = ''.join(data[split]['text']).encode('utf-8', errors='replace')
-
-        self.data: Tensor = torch.tensor(bytearray(self.__train), dtype=torch.long)
+        print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
+        ft = Features({'text': Value('string')})
+        data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
+        self.data = data['text']
+        self.size = size
 
         # Model uses fixed 128-length context
        self.context_length = 128
 
-    def __len__(self):
-        return len(self.data) - self.context_length
+        self.process_data()
 
-    def __getitem__(self, item):
-        x = self.data[item: item + self.context_length]
-        y = self.data[item + self.context_length]
+        print("Done initializing dataset")
+
+    def __len__(self):
+        # return len(self.data) - self.context_length
+        return self.chunk_offsets[-1] - self.context_length
+
+    def __getitem__(self, idx):
+        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
+        x = self.tensor[idx:idx + self.context_length]
+        y = self.tensor[idx + self.context_length]
 
         if self.transform:
             x = self.transform(x)
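All three loaders now share the same byte-level sliding-window scheme over `self.tensor`; here it is shrunk to a 4-byte context so the indexing is easy to see (toy data, the models use 128):

```python
import torch

# Toy byte stream standing in for self.tensor
tensor = torch.tensor(list("abcdefgh".encode()), dtype=torch.long)
context_length = 4

n_windows = len(tensor) - context_length  # what __len__ returns
for idx in range(n_windows):
    x = tensor[idx:idx + context_length]  # context window
    y = tensor[idx + context_length]      # next-byte target
    print(bytes(x.tolist()).decode(), '->', chr(int(y)))
# abcd -> e, bcde -> f, cdef -> g, defg -> h
```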
diff --git a/CNN-model/main_cnn.py b/CNN-model/main_cnn.py
index a66710a..1789572 100644
--- a/CNN-model/main_cnn.py
+++ b/CNN-model/main_cnn.py
@@ -10,6 +10,8 @@
 from trainers import OptunaTrainer, Trainer, FullTrainer
 
 
 def parse_arguments():
     parser = ArgumentParser(prog="NeuralCompression")
+    parser.add_argument("--debug", "-d", action="store_true", required=False,
+                        help="Enable debug mode: smaller datasets, more information")
     parser.add_argument("--verbose", "-v", action="store_true", required=False,
                         help="Enable verbose mode")
@@ -18,7 +20,7 @@ def parse_arguments():
     dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
 
     modelparser = ArgumentParser(add_help=False)
-    modelparser.add_argument("--model-path", type=str, required=True,
+    modelparser.add_argument("--model-path", type=str, required=False,
                              help="Path to the model to load/save")
 
     fileparser = ArgumentParser(add_help=False)
@@ -33,6 +35,8 @@ def parse_arguments():
                                 help="Only fetch the dataset, then exit")
 
     train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
+    train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
+                              help="Method to use for training")
 
     # TODO
     compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
@@ -44,7 +48,7 @@ def parse_arguments():
 
 
 def main():
-    BATCH_SIZE = 64
+    BATCH_SIZE = 2
 
     # hyper parameters
     context_length = 128
@@ -57,9 +61,18 @@ def main():
         DEVICE = "cpu"
     print(f"Running on device: {DEVICE}...")
 
+    dataset_common_args = {
+        'root': args.data_root,
+        'transform': lambda x: x.to(DEVICE)
+    }
+
+    if args.debug:
+        dataset_common_args['size'] = 2**10
+
     print("Loading in the dataset...")
     if args.dataset in dataset_called:
-        dataset = dataset_called[args.dataset](root=args.data_root, transform=lambda x: x.to(DEVICE))
+        training_set = dataset_called[args.dataset](split='train', **dataset_common_args)
+        validate_set = dataset_called[args.dataset](split='validation', **dataset_common_args)
     else:
         # TODO Allow to import arbitrary files
         raise NotImplementedError(f"Importing external datasets is not implemented yet")
@@ -68,16 +81,10 @@ def main():
         # TODO More to earlier in chain, because now everything is converted into tensors as well?
         exit(0)
 
-    dataset_length = len(dataset)
-    print(f"Dataset size = {dataset_length}")
-
-    training_size = ceil(0.8 * dataset_length)
-
-    print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")
-
-    train_set, validate_set = torch.utils.data.random_split(dataset, [training_size, dataset_length - training_size])
-    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
+    print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
+    training_loader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
     validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
+
     loss_fn = torch.nn.CrossEntropyLoss()
 
     model = None
@@ -85,8 +92,9 @@ def main():
         print("Loading the model...")
         model = torch.load(args.model_path)
 
-    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
+    trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer()
 
+    print("Training")
     trainer.execute(
         model=model,
         train_loader=training_loader,
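Example invocations of the updated CLI (flags as defined in this patch; the checkpoint path in the second line is an assumption):

```shell
# Debug run: 2**10-byte datasets, 3 Optuna trials
python main_cnn.py --debug train --dataset enwik9 --method optuna

# Full training, optionally resuming from a saved model
python main_cnn.py train --dataset enwik9 --method full --model-path models/final_model.pt
```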
diff --git a/CNN-model/models/final_model.pt b/CNN-model/models/final_model.pt
index 3b0aae5..54d5bcb 100644
Binary files a/CNN-model/models/final_model.pt and b/CNN-model/models/final_model.pt differ
diff --git a/CNN-model/trainers/OptunaTrainer.py b/CNN-model/trainers/OptunaTrainer.py
index b0d8d9e..6f0b3b9 100644
--- a/CNN-model/trainers/OptunaTrainer.py
+++ b/CNN-model/trainers/OptunaTrainer.py
@@ -35,6 +35,11 @@
 
 
 class OptunaTrainer(Trainer):
+    def __init__(self, n_trials: int | None = None):
+        super().__init__()
+        self.n_trials = n_trials if n_trials is not None else 20
+        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
+
     def execute(
             self,
             model: nn.Module | None,
@@ -47,7 +52,7 @@ class OptunaTrainer(Trainer):
         study = optuna.create_study(study_name="CNN network", direction="minimize")
         study.optimize(
             lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
-            n_trials=20
+            n_trials=self.n_trials
         )
 
         best_params = study.best_trial.params
diff --git a/README.md b/README.md
index 7a3854e..a183021 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,11 @@
 # neural compression
 
+Example usage:
+
+```shell
+python main_cnn.py --debug train --dataset enwik9 --method optuna
+```
+
 ## Running locally
 
 ```
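For context, the `OptunaTrainer` change above boils down to the standard Optuna loop with a configurable trial budget; a stripped-down equivalent with a stand-in objective (the real one builds and trains the CNN):

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    # Stand-in objective; minimized like the validation loss
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    return (lr - 1e-3) ** 2

study = optuna.create_study(study_name="CNN network", direction="minimize")
study.optimize(objective, n_trials=3)  # --debug uses 3 trials; the default stays 20
print(study.best_trial.params)
```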