This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../main.py
2025-12-09 15:10:12 +01:00

45 lines
1.2 KiB
Python

import torch
from src.args import parse_arguments
from src.process import compress
from src.train import train
def main():
    """Command-line entry point: parse arguments, pick a device, run a mode.

    The arguments come from ``parse_arguments()``, and ``args.mode`` selects
    the action:

    * ``'train'``    -- calls ``train`` with the dataset, method and model
      options taken from ``args``.
    * ``'compress'`` -- calls ``compress`` with the model path and the
      input/output file names taken from ``args``.

    ``"Done"`` is printed only after a recognised mode has finished.

    Raises:
        NotImplementedError: if ``args.mode`` is neither ``'train'`` nor
            ``'compress'``.
    """
    # parse_arguments() also returns ``print_help``, which is not used here.
    # NOTE(review): possibly intended to be shown for an unknown/missing mode.
    args, print_help = parse_arguments()

    # Use the type string of the current torch accelerator when one is
    # available, otherwise fall back to "cpu". Requires a PyTorch version
    # that provides ``torch.accelerator``.
    has_accelerator = torch.accelerator.is_available()
    device = torch.accelerator.current_accelerator().type if has_accelerator else "cpu"
    print(f"Running on device: {device}...")

    mode = args.mode
    if mode == 'train':
        train(
            device=device,
            dataset=args.dataset,
            data_root=args.data_root,
            # In debug mode use 3 trials and a size of 2**10; otherwise pass
            # None (presumably letting ``train`` choose -- confirm in src.train).
            n_trials=3 if args.debug else None,
            size=2 ** 10 if args.debug else None,
            method=args.method,
            model_name=args.model,
            model_path=args.model_load_path,
            model_out=args.model_save_path,
        )
    elif mode == 'compress':
        compress(
            device=device,
            model_path=args.model_load_path,
            input_file=args.input_file,
            output_file=args.output_file,
        )
    else:
        raise NotImplementedError(f"Mode {mode} is not implemented yet")

    print("Done")


# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()