diff --git a/main.py b/main.py index df5f62d..5c74ba8 100644 --- a/main.py +++ b/main.py @@ -17,23 +17,27 @@ def main(): match args.mode: case 'train': train( - device = device, - dataset = args.dataset, - data_root = args.data_root, - n_trials = 3 if args.debug else None, - size = 2**10 if args.debug else None, - method = args.method, + device=device, + dataset=args.dataset, + data_root=args.data_root, + n_trials=3 if args.debug else None, + size=2 ** 10 if args.debug else None, + method=args.method, model_name=args.model, - model_path = args.model_load_path, - model_out = args.model_save_path + model_path=args.model_load_path, + model_out=args.model_save_path ) case 'compress': - compress(args.input_file) + compress(device=device, + model_path=args.model_load_path, + input_file=args.input_file, + output_file=args.output_file + ) case _: raise NotImplementedError(f"Mode {args.mode} is not implemented yet") - + print("Done") diff --git a/src/args.py b/src/args.py index 698089a..5511906 100644 --- a/src/args.py +++ b/src/args.py @@ -19,7 +19,7 @@ def parse_arguments(): help="Which model to use") modelparser.add_argument("--model-load-path", type=str, required=False, help="Filepath to the model to load") - modelparser.add_argument("--model-save-path", type=str, required=True, + modelparser.add_argument("--model-save-path", type=str, required=False, help="Filepath to the model to save") fileparser = ArgumentParser(add_help=False) diff --git a/src/dataset_loaders/Dataset.py b/src/dataset_loaders/Dataset.py index 7bd3c2d..e37f0cc 100644 --- a/src/dataset_loaders/Dataset.py +++ b/src/dataset_loaders/Dataset.py @@ -50,12 +50,12 @@ class Dataset(TorchDataset, ABC): return len(self.dataset) def process_data(self): + self.chunk_offsets = self.get_offsets() if self.size == -1: # Just use the whole dataset self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace') else: # Use only partition, calculate offsets - self.chunk_offsets = self.get_offsets() self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace') bytes_array = np.frombuffer(self.bytes, dtype=np.uint8) # Zero-copy diff --git a/src/utils/utils.py b/src/utils/utils.py index df27ee5..24fa61d 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -1,3 +1,5 @@ +from os import path + import torch from torch.utils.data import TensorDataset import matplotlib.pyplot as plt @@ -14,7 +16,7 @@ def print_distribution(from_to: tuple[int, int], probabilities: list[float]): plt.hist(range(from_to[0], from_to[1]), weights=probabilities) plt.show() -def print_losses(train_losses: list[float], validation_losses: list[float], show=False): +def print_losses(train_losses: list[float], validation_losses: list[float], filename: str | None = None, show=False): plt.plot(train_losses, label="Training loss") plt.plot(validation_losses, label="Validation loss") plt.xlabel("Epoch") @@ -23,7 +25,12 @@ def print_losses(train_losses: list[float], validation_losses: list[float], show if show: plt.show() - plt.savefig("losses.png") + + if filename is None: + filename = path.join("results", "losses.png") + + print(f"Saving losses to {filename}...") + plt.savefig(filename) def load_data(path: str) -> bytes: