"""Command-line entry point that dispatches to training or compression."""

import torch

from src.args import parse_arguments
from src.process import compress
from src.train import train


def main():
    """Parse the CLI arguments, pick a device, and run the requested mode.

    Supported values of ``args.mode``:

    * ``'train'``    -- calls ``train`` with the dataset/model options taken
      from the parsed arguments.
    * ``'compress'`` -- calls ``compress`` on ``args.input_file``.

    Raises:
        NotImplementedError: if ``args.mode`` is any other value.
    """
    # parse_arguments() returns a pair; the second item is not used here.
    args, _print_help = parse_arguments()

    # Use the current accelerator's device type when one is available,
    # otherwise fall back to the CPU.
    device = (
        torch.accelerator.current_accelerator().type
        if torch.accelerator.is_available()
        else "cpu"
    )
    print(f"Running on device: {device}...")

    mode = args.mode
    if mode == 'train':
        # With --debug, pass small fixed limits; otherwise pass None.
        # NOTE(review): presumably None lets train() use its own defaults
        # -- confirm against src.train.
        debug_trials = 3 if args.debug else None
        debug_size = 2**10 if args.debug else None
        train(
            device=device,
            dataset=args.dataset,
            data_root=args.data_root,
            n_trials=debug_trials,
            size=debug_size,
            method=args.method,
            model_name=args.model,
            model_path=args.model_load_path,
            model_out=args.model_save_path,
        )
    elif mode == 'compress':
        compress(args.input_file)
    else:
        raise NotImplementedError(f"Mode {args.mode} is not implemented yet")

    print("Done")


if __name__ == "__main__":
    main()