diff --git a/main.py b/main.py index 23a9850..7ce5560 100644 --- a/main.py +++ b/main.py @@ -12,11 +12,14 @@ def main(): match args.mode: case 'train': - size = None + size = args.size if args.method == 'optuna': size = 2 ** 12 + print(f"Using size {size} for optuna (was {args.size})") if args.debug: size = 2 ** 10 + print(f"Using size {size} for debug (was {args.size})") + train( device=device, dataset=args.dataset, diff --git a/src/args.py b/src/args.py index 5511906..3fc8325 100644 --- a/src/args.py +++ b/src/args.py @@ -35,6 +35,8 @@ def parse_arguments(): train_parser.add_argument("--method", choices=["fetch", "optuna", "full"], required=True, help="Method to use for training") + train_parser.add_argument("--size", "-s", type=int, required=False, + help="Size of the subset of the dataset to use") compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])