63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
from src.args import parse_arguments
|
|
from src.process import compress, decompress
|
|
from src.train import train
|
|
from src.utils import determine_device
|
|
|
|
|
|
def main():
|
|
args, print_help = parse_arguments()
|
|
|
|
device = args.device or determine_device()
|
|
print(f"Running on device: {device}...")
|
|
|
|
match args.mode:
|
|
case 'train':
|
|
size = int(args.size) if args.size else None
|
|
if args.method == 'optuna':
|
|
size = min(size, 2 ** 12) if size else 2 ** 12
|
|
if size != args.size:
|
|
print(f"Using size {size} for optuna (was {args.size})")
|
|
if args.debug:
|
|
size = min(size, 2 ** 10) if size else 2 ** 10
|
|
if size != args.size:
|
|
print(f"Using size {size} for debug (was {args.size})")
|
|
|
|
train(
|
|
device=device,
|
|
dataset=args.dataset,
|
|
data_root=args.data_root,
|
|
n_trials=3 if args.debug else None,
|
|
size=size,
|
|
method=args.method,
|
|
model_name=args.model,
|
|
model_path=args.model_load_path,
|
|
model_out=args.model_save_path,
|
|
context_length=args.context,
|
|
results_dir=args.results
|
|
)
|
|
|
|
case 'compress':
|
|
compress(device=device,
|
|
model_name=args.model,
|
|
model_path=args.model_load_path,
|
|
input_file=args.input_file,
|
|
output_file=args.output_file,
|
|
context_length=args.context
|
|
)
|
|
case 'decompress':
|
|
decompress(
|
|
device=device,
|
|
model_name=args.model,
|
|
model_path=args.model_load_path,
|
|
input_file=args.input_file,
|
|
output_file=args.output_file,
|
|
context_length=args.context
|
|
)
|
|
case _:
|
|
raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
|
|
|
|
print("Done")
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the CLI only when this file is executed as a script;
    # importing the module defines main() without running it.
    main()
|