This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../main.py

38 lines
993 B
Python

import torch
from src.args import parse_arguments
from src.process import compress
from src.train import train
def main():
args, print_help = parse_arguments()
if torch.accelerator.is_available():
device = torch.accelerator.current_accelerator().type
else:
device = "cpu"
print(f"Running on device: {device}...")
match args.mode:
case 'train':
train(
device = device,
dataset = args.dataset,
data_root = args.data_root,
n_trials = 3 if args.debug else None,
size = 2**10 if args.debug else None,
model_name=args.model,
model_path = args.model_load_path,
model_out = args.model_save_path
)
case 'compress':
compress(args.input_file)
case _:
raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not on import.
    # main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())