diff --git a/main.py b/main.py index 7ce5560..41fdb8e 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,7 @@ def main(): match args.mode: case 'train': - size = args.size + size = int(args.size) if args.size else None if args.method == 'optuna': size = 2 ** 12 print(f"Using size {size} for optuna (was {args.size})")