from src.args import parse_arguments from src.process import compress, decompress from src.train import train from src.utils import determine_device def main(): args, print_help = parse_arguments() device = args.device or determine_device() print(f"Running on device: {device}...") match args.mode: case 'train': size = int(args.size) if args.size else None if args.method == 'optuna': size = min(size, 2 ** 12) if size else 2 ** 12 if size != args.size: print(f"Using size {size} for optuna (was {args.size})") if args.debug: size = min(size, 2 ** 10) if size else 2 ** 10 if size != args.size: print(f"Using size {size} for debug (was {args.size})") train( device=device, dataset=args.dataset, data_root=args.data_root, n_trials=3 if args.debug else None, size=size, method=args.method, model_name=args.model, model_path=args.model_load_path, model_out=args.model_save_path, context_length=args.context, results_dir=args.results ) case 'compress': compress(device=device, model_name=args.model, model_path=args.model_load_path, input_file=args.input_file, output_file=args.output_file, context_length=args.context ) case 'decompress': decompress( device=device, model_name=args.model, model_path=args.model_load_path, input_file=args.input_file, output_file=args.output_file, context_length=args.context ) case _: raise NotImplementedError(f"Mode {args.mode} is not implemented yet") print("Done") if __name__ == "__main__": main()