41 lines
1 KiB
Python
41 lines
1 KiB
Python
import torch
|
|
|
|
from src.args import parse_arguments
|
|
from src.process import compress
|
|
from src.train import train
|
|
|
|
|
|
def main():
    """Parse CLI arguments, choose a compute device, and dispatch on ``args.mode``.

    Supported modes:
      * ``'train'``    -- calls ``train(...)`` with settings read from ``args``.
        When ``args.debug`` is truthy, ``n_trials=3`` and ``size=2**10`` are
        passed; otherwise both are passed as ``None``.
      * ``'compress'`` -- calls ``compress(args.input_file)``.

    Prints the chosen device before dispatching and ``"Done"`` afterwards.

    Raises:
        NotImplementedError: if ``args.mode`` is not one of the modes above.
    """
    # parse_arguments() returns a pair; only the parsed args are used here.
    args, _print_help = parse_arguments()

    # Use the type string of torch's current accelerator when one is
    # available, otherwise fall back to the CPU.
    device = (
        torch.accelerator.current_accelerator().type
        if torch.accelerator.is_available()
        else "cpu"
    )
    print(f"Running on device: {device}...")

    mode = args.mode
    if mode == 'train':
        train(
            device=device,
            dataset=args.dataset,
            data_root=args.data_root,
            # Debug runs pass reduced n_trials/size; normal runs pass None.
            n_trials=3 if args.debug else None,
            size=2**10 if args.debug else None,
            method=args.method,
            model_name=args.model,
            model_path=args.model_load_path,
            model_out=args.model_save_path,
        )
    elif mode == 'compress':
        compress(args.input_file)
    else:
        raise NotImplementedError(f"Mode {args.mode} is not implemented yet")

    print("Done")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
|