From 926cde17d3aae6db08bc4e927ab7dc7b86671acd Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Tue, 9 Dec 2025 15:10:12 +0100 Subject: [PATCH 1/2] Checkpoint --- main.py | 24 ++++++++++++++---------- src/args.py | 2 +- src/utils/utils.py | 11 +++++++++-- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index df5f62d..5c74ba8 100644 --- a/main.py +++ b/main.py @@ -17,23 +17,27 @@ def main(): match args.mode: case 'train': train( - device = device, - dataset = args.dataset, - data_root = args.data_root, - n_trials = 3 if args.debug else None, - size = 2**10 if args.debug else None, - method = args.method, + device=device, + dataset=args.dataset, + data_root=args.data_root, + n_trials=3 if args.debug else None, + size=2 ** 10 if args.debug else None, + method=args.method, model_name=args.model, - model_path = args.model_load_path, - model_out = args.model_save_path + model_path=args.model_load_path, + model_out=args.model_save_path ) case 'compress': - compress(args.input_file) + compress(device=device, + model_path=args.model_load_path, + input_file=args.input_file, + output_file=args.output_file + ) case _: raise NotImplementedError(f"Mode {args.mode} is not implemented yet") - + print("Done") diff --git a/src/args.py b/src/args.py index 698089a..5511906 100644 --- a/src/args.py +++ b/src/args.py @@ -19,7 +19,7 @@ def parse_arguments(): help="Which model to use") modelparser.add_argument("--model-load-path", type=str, required=False, help="Filepath to the model to load") - modelparser.add_argument("--model-save-path", type=str, required=True, + modelparser.add_argument("--model-save-path", type=str, required=False, help="Filepath to the model to save") fileparser = ArgumentParser(add_help=False) diff --git a/src/utils/utils.py b/src/utils/utils.py index df27ee5..24fa61d 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -1,3 +1,5 @@ +from os import path + import torch from torch.utils.data import TensorDataset import matplotlib.pyplot as plt @@ -14,7 +16,7 @@ def print_distribution(from_to: tuple[int, int], probabilities: list[float]): plt.hist(range(from_to[0], from_to[1]), weights=probabilities) plt.show() -def print_losses(train_losses: list[float], validation_losses: list[float], show=False): +def print_losses(train_losses: list[float], validation_losses: list[float], filename: str | None = None, show=False): plt.plot(train_losses, label="Training loss") plt.plot(validation_losses, label="Validation loss") plt.xlabel("Epoch") @@ -23,7 +25,12 @@ def print_losses(train_losses: list[float], validation_losses: list[float], show if show: plt.show() - plt.savefig("losses.png") + + if filename is None: + filename = path.join("results", "losses.png") + + print(f"Saving losses to {filename}...") + plt.savefig(filename) def load_data(path: str) -> bytes: From 291be01069118f7cdea63cc1d3de3c4dda59797d Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Tue, 9 Dec 2025 22:08:52 +0100 Subject: [PATCH 2/2] fix: Offsets --- src/dataset_loaders/Dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dataset_loaders/Dataset.py b/src/dataset_loaders/Dataset.py index 63763af..bc643dd 100644 --- a/src/dataset_loaders/Dataset.py +++ b/src/dataset_loaders/Dataset.py @@ -49,12 +49,12 @@ class Dataset(TorchDataset, ABC): return len(self.dataset) def process_data(self): + self.chunk_offsets = self.get_offsets() if self.size == -1: # Just use the whole dataset self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace') else: # Use only partition, calculate offsets - self.chunk_offsets = self.get_offsets() self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace') self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)