45 lines
1.2 KiB
Python
45 lines
1.2 KiB
Python
import torch
|
|
|
|
from src.args import parse_arguments
|
|
from src.process import compress
|
|
from src.train import train
|
|
|
|
|
|
def main():
|
|
args, print_help = parse_arguments()
|
|
|
|
if torch.accelerator.is_available():
|
|
device = torch.accelerator.current_accelerator().type
|
|
else:
|
|
device = "cpu"
|
|
print(f"Running on device: {device}...")
|
|
|
|
match args.mode:
|
|
case 'train':
|
|
train(
|
|
device=device,
|
|
dataset=args.dataset,
|
|
data_root=args.data_root,
|
|
n_trials=3 if args.debug else None,
|
|
size=2 ** 10 if args.debug else None,
|
|
method=args.method,
|
|
model_name=args.model,
|
|
model_path=args.model_load_path,
|
|
model_out=args.model_save_path
|
|
)
|
|
|
|
case 'compress':
|
|
compress(device=device,
|
|
model_path=args.model_load_path,
|
|
input_file=args.input_file,
|
|
output_file=args.output_file
|
|
)
|
|
|
|
case _:
|
|
raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
|
|
|
|
print("Done")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
|