from src.args import parse_arguments from src.process import compress from src.train import train from src.utils import determine_device def main(): args, print_help = parse_arguments() device = determine_device() print(f"Running on device: {device}...") match args.mode: case 'train': size = int(args.size) if args.size else None if args.method == 'optuna': size = 2 ** 12 print(f"Using size {size} for optuna (was {args.size})") if args.debug: size = 2 ** 10 print(f"Using size {size} for debug (was {args.size})") train( device=device, dataset=args.dataset, data_root=args.data_root, n_trials=3 if args.debug else None, size=size, method=args.method, model_name=args.model, model_path=args.model_load_path, model_out=args.model_save_path ) case 'compress': compress(device=device, model_path=args.model_load_path, input_file=args.input_file, output_file=args.output_file ) case _: raise NotImplementedError(f"Mode {args.mode} is not implemented yet") print("Done") if __name__ == "__main__": main()