This repository was archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing or creating new issues, pull requests, or comments.
2025ML-project-neural_compr.../main.py

53 lines
1.6 KiB
Python

from src.args import parse_arguments
from src.process import compress
from src.train import train
from src.utils import determine_device
def main():
args, print_help = parse_arguments()
device = args.device or determine_device()
print(f"Running on device: {device}...")
match args.mode:
case 'train':
size = int(args.size) if args.size else None
if args.method == 'optuna':
size = min(size, 2 ** 12) if size else 2 ** 12
if size != args.size:
print(f"Using size {size} for optuna (was {args.size})")
if args.debug:
size = min(size, 2 ** 10) if size else 2 ** 10
if size != args.size:
print(f"Using size {size} for debug (was {args.size})")
train(
device=device,
dataset=args.dataset,
data_root=args.data_root,
n_trials=3 if args.debug else None,
size=size,
method=args.method,
model_name=args.model,
model_path=args.model_load_path,
model_out=args.model_save_path,
context_length=args.context,
results_dir=args.results
)
case 'compress':
compress(device=device,
model_path=args.model_load_path,
input_file=args.input_file,
output_file=args.output_file
)
case _:
raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
print("Done")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()