41 lines
1.1 KiB
Python
41 lines
1.1 KiB
Python
from src.args import parse_arguments
|
|
from src.process import compress
|
|
from src.train import train
|
|
from src.utils import determine_device
|
|
|
|
|
|
def main():
    """Parse the command-line arguments and run the requested mode.

    The device returned by ``determine_device()`` is announced on stdout and
    forwarded to whichever mode ``args.mode`` selects:

    * ``'train'``    -> ``train(...)``
    * ``'compress'`` -> ``compress(...)``

    ``"Done"`` is printed once the selected mode returns.

    Raises:
        NotImplementedError: if ``args.mode`` is neither ``'train'`` nor
            ``'compress'``.
    """
    # parse_arguments() returns a pair; the second element is not used here.
    args, _print_help = parse_arguments()

    device = determine_device()
    print(f"Running on device: {device}...")

    mode = args.mode
    if mode == 'train':
        # With args.debug set, n_trials is 3 and size is 2 ** 10; otherwise
        # None is passed for both.
        train_kwargs = {
            'device': device,
            'dataset': args.dataset,
            'data_root': args.data_root,
            'n_trials': 3 if args.debug else None,
            'size': 2 ** 10 if args.debug else None,
            'method': args.method,
            'model_name': args.model,
            'model_path': args.model_load_path,
            'model_out': args.model_save_path,
        }
        train(**train_kwargs)
    elif mode == 'compress':
        compress_kwargs = {
            'device': device,
            'model_path': args.model_load_path,
            'input_file': args.input_file,
            'output_file': args.output_file,
        }
        compress(**compress_kwargs)
    else:
        raise NotImplementedError(f"Mode {args.mode} is not implemented yet")

    print("Done")
|
|
|
|
|
|
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|