From 9dff723bbaad6434ccfe241413ffcdee0436a873 Mon Sep 17 00:00:00 2001
From: Tibo De Peuter
Date: Fri, 5 Dec 2025 12:38:10 +0100
Subject: [PATCH] chore: Start compression path

---
 main.py        | 116 ++++++++++---------------------------------------
 src/args.py    |  42 ++++++++++++++++++
 src/process.py |  23 ++++++++++
 src/train.py   |  61 ++++++++++++++++++++++++++
 4 files changed, 148 insertions(+), 94 deletions(-)
 create mode 100644 src/args.py
 create mode 100644 src/process.py
 create mode 100644 src/train.py

diff --git a/main.py b/main.py
index d43cab5..db9147b 100644
--- a/main.py
+++ b/main.py
@@ -1,108 +1,36 @@
-from argparse import ArgumentParser
-from math import ceil
-
 import torch
-from torch.utils.data import DataLoader
 
-from dataset_loaders import dataset_called
-from trainers import OptunaTrainer, Trainer, FullTrainer
-
-
-def parse_arguments():
-    parser = ArgumentParser(prog="NeuralCompression")
-    parser.add_argument("--debug", "-d", action="store_true", required=False,
-                        help="Enable debug mode: smaller datasets, more information")
-    parser.add_argument("--verbose", "-v", action="store_true", required=False,
-                        help="Enable verbose mode")
-
-    dataparser = ArgumentParser(add_help=False)
-    dataparser.add_argument("--data-root", type=str, required=False)
-    dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
-
-    modelparser = ArgumentParser(add_help=False)
-    modelparser.add_argument("--model-path", type=str, required=False,
-                             help="Path to the model to load/save")
-
-    fileparser = ArgumentParser(add_help=False)
-    fileparser.add_argument("--input-file", "-i", required=False, type=str)
-    fileparser.add_argument("--output-file", "-o", required=False, type=str)
-
-    subparsers = parser.add_subparsers(dest="mode", required=True,
-                                       help="Mode to run in")
-
-    # TODO
-    fetch_parser = subparsers.add_parser("fetch", parents=[dataparser],
-                                         help="Only fetch the dataset, then exit")
-
-    train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
-    train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
-                              help="Method to use for training")
-
-    # TODO
-    compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
-
-    # TODO
-    decompress_parser = subparsers.add_parser("decompress", parents=[modelparser, fileparser])
-
-    return parser.parse_args()
+from src.args import parse_arguments
+from src.process import compress
+from src.train import train
 
 
 def main():
-    BATCH_SIZE = 2
-
-    # hyper parameters
-    context_length = 128
-
-    args = parse_arguments()
+    args, print_help = parse_arguments()
 
     if torch.accelerator.is_available():
-        DEVICE = torch.accelerator.current_accelerator().type
+        device = torch.accelerator.current_accelerator().type
     else:
-        DEVICE = "cpu"
-    print(f"Running on device: {DEVICE}...")
+        device = "cpu"
+    print(f"Running on device: {device}...")
 
-    dataset_common_args = {
-        'root': args.data_root,
-        'transform': lambda x: x.to(DEVICE)
-    }
+    match args.mode:
+        case 'train':
+            train(
+                device=device,
+                dataset=args.dataset,
+                data_root=args.data_root,
+                method=args.method,
+                n_trials=3 if args.debug else None,
+                size=2**10 if args.debug else None,
+                model_path=args.model_path
+            )
 
-    if args.debug:
-        dataset_common_args['size'] = 2**10
+        case 'compress':
+            compress(args.input_file)
 
-    print("Loading in the dataset...")
-    if args.dataset in dataset_called:
-        training_set = dataset_called[args.dataset](split='train', **dataset_common_args)
-        validate_set = dataset_called[args.dataset](split='validation', **dataset_common_args)
-    else:
-        # TODO Allow to import arbitrary files
-        raise NotImplementedError(f"Importing external datasets is not implemented yet")
-
-    if args.mode == 'fetch':
-        # TODO More to earlier in chain, because now everything is converted into tensors as well?
-        exit(0)
-
-    print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
-    training_loader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
-    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
-
-    loss_fn = torch.nn.CrossEntropyLoss()
-
-    model = None
-    if args.model_path is not None:
-        print("Loading the models...")
-        model = torch.load(args.model_path)
-
-    trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer()
-
-    print("Training")
-    trainer.execute(
-        model=model,
-        train_loader=training_loader,
-        validation_loader=validation_loader,
-        loss_fn=loss_fn,
-        n_epochs=200,
-        device=DEVICE
-    )
+        case _:
+            raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
 
 
 if __name__ == "__main__":
diff --git a/src/args.py b/src/args.py
new file mode 100644
index 0000000..55a72c2
--- /dev/null
+++ b/src/args.py
@@ -0,0 +1,42 @@
+from argparse import ArgumentParser
+
+from src.dataset_loaders import dataset_called
+
+
+def parse_arguments():
+    parser = ArgumentParser(prog="NeuralCompression")
+    parser.add_argument("--debug", "-d", action="store_true", required=False,
+                        help="Enable debug mode: smaller datasets, more information")
+    parser.add_argument("--verbose", "-v", action="store_true", required=False,
+                        help="Enable verbose mode")
+
+    dataparser = ArgumentParser(add_help=False)
+    dataparser.add_argument("--data-root", type=str, required=False)
+    dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
+
+    modelparser = ArgumentParser(add_help=False)
+    modelparser.add_argument("--model-path", type=str, required=False,
+                             help="Path to the model to load/save")
+
+    fileparser = ArgumentParser(add_help=False)
+    fileparser.add_argument("--input-file", "-i", required=False, type=str)
+    fileparser.add_argument("--output-file", "-o", required=False, type=str)
+
+    subparsers = parser.add_subparsers(dest="mode", required=True,
+                                       help="Mode to run in")
+
+    # TODO
+    fetch_parser = subparsers.add_parser("fetch", parents=[dataparser],
+                                         help="Only fetch the dataset, then exit")
+
+    train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
+    train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
+                              help="Method to use for training")
+
+    # TODO
+    compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
+
+    # TODO
+    decompress_parser = subparsers.add_parser("decompress", parents=[modelparser, fileparser])
+
+    return parser.parse_args(), parser.print_help
diff --git a/src/process.py b/src/process.py
new file mode 100644
index 0000000..cee5a88
--- /dev/null
+++ b/src/process.py
@@ -0,0 +1,23 @@
+import sys
+
+import torch
+
+
+def compress(
+        input_file: str | None = None
+):
+    if input_file:
+        with open(input_file, "rb") as file:
+            byte_data = file.read()
+    else:
+        byte_data = sys.stdin.buffer.read()  # read the raw byte stream from stdin
+
+    tensor = torch.tensor(list(byte_data), dtype=torch.long)
+    print(tensor)
+
+    # TODO Feed to model for compression, store result
+    return
+
+
+def decompress():
+    raise NotImplementedError("Decompression is not implemented yet")
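Aside, not part of the patch: the TODO in compress() leaves open how the byte tensor becomes a compressed artifact. As a rough sketch only, assuming a trained autoregressive model that maps a (1, T) tensor of byte ids to (1, T, 256) next-byte logits (an assumed interface, not one this repo defines yet), the ideal code length can be estimated from the model's log-probabilities; an entropy coder would then turn the same probabilities into an actual bitstream:

    import math

    import torch

    def estimated_bits(model: torch.nn.Module, tensor: torch.Tensor) -> float:
        # Score each byte after the first under the model's prediction for it;
        # -sum(log2 p) is the length an ideal entropy coder would achieve.
        with torch.no_grad():
            logits = model(tensor.unsqueeze(0))            # (1, T, 256), assumed shape
            log_probs = torch.log_softmax(logits, dim=-1)  # log p(byte | prefix)
            targets = tensor[1:].unsqueeze(0).unsqueeze(-1)
            picked = log_probs[:, :-1, :].gather(2, targets)
        return -picked.sum().item() / math.log(2)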
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000..1faffb9
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,61 @@
+import torch
+from torch.utils.data import DataLoader
+
+from src.dataset_loaders import dataset_called
+from src.trainers import OptunaTrainer, Trainer, FullTrainer
+
+
+def train(
+        device,
+        dataset: str,
+        data_root: str,
+        n_trials: int | None = None,
+        size: int | None = None,
+        mode: str = "train",
+        method: str = 'optuna',
+        model_path: str | None = None,
+):
+    batch_size = 2
+
+    dataset_common_args = {
+        'root': data_root,
+        'transform': lambda x: x.to(device),
+    }
+
+    if size:
+        dataset_common_args['size'] = size
+
+    print("Loading in the dataset...")
+    if dataset in dataset_called:
+        training_set = dataset_called[dataset](split='train', **dataset_common_args)
+        validate_set = dataset_called[dataset](split='validation', **dataset_common_args)
+    else:
+        # TODO Allow importing arbitrary files
+        raise NotImplementedError("Importing external datasets is not implemented yet")
+
+    if mode == 'fetch':
+        # TODO Move this earlier in the chain, because now everything is converted into tensors as well?
+        exit(0)
+
+    print(f"Training set size = {len(training_set)}, Validation set size = {len(validate_set)}")
+    training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
+    validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
+
+    loss_fn = torch.nn.CrossEntropyLoss()
+
+    model = None
+    if model_path is not None:
+        print("Loading the model...")
+        model = torch.load(model_path)
+
+    trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
+
+    print("Training")
+    trainer.execute(
+        model=model,
+        train_loader=training_loader,
+        validation_loader=validation_loader,
+        loss_fn=loss_fn,
+        n_epochs=200,
+        device=device
+    )
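Aside, not part of the patch: a minimal smoke test for the new train() entry point, bypassing the CLI. The dataset key 'enwik8' and the './data' root are hypothetical placeholders; substitute whatever actually exists in src.dataset_loaders.dataset_called:

    from src.train import train

    # Tiny debug-sized run: one Optuna trial over 256 samples on CPU.
    train(device="cpu", dataset="enwik8", data_root="./data",
          method="optuna", n_trials=1, size=2**8)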