"""Command-line entry point: pick a compute device, then run the requested mode."""
import torch

from src.args import parse_arguments
from src.process import compress
from src.train import train


def _detect_device():
    """Return the current accelerator's device type string, or ``"cpu"`` if none is available."""
    if not torch.accelerator.is_available():
        return "cpu"
    return torch.accelerator.current_accelerator().type


def main():
    """Parse CLI arguments and run the selected mode.

    Supported modes are ``'train'`` (delegates to ``train``) and ``'compress'``
    (delegates to ``compress``). Any other mode raises ``NotImplementedError``.
    Prints the chosen device before running and ``"Done"`` afterwards.
    """
    # The second return value is presumably a help printer; it is not used here.
    args, _print_help = parse_arguments()

    device = _detect_device()
    print(f"Running on device: {device}...")

    mode = args.mode
    if mode == 'train':
        # Debug runs pass n_trials=3 and size=2**10; otherwise None is passed
        # through to train() for both.
        is_debug = args.debug
        train(
            device=device,
            dataset=args.dataset,
            data_root=args.data_root,
            n_trials=3 if is_debug else None,
            size=2 ** 10 if is_debug else None,
            method=args.method,
            model_name=args.model,
            model_path=args.model_load_path,
            model_out=args.model_save_path,
        )
    elif mode == 'compress':
        compress(
            device=device,
            model_path=args.model_load_path,
            input_file=args.input_file,
            output_file=args.output_file,
        )
    else:
        raise NotImplementedError(f"Mode {args.mode} is not implemented yet")

    print("Done")


if __name__ == "__main__":
    main()