From 67310ce4f4d787efaf2fdf705521ea1af28402e4 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Wed, 10 Dec 2025 16:29:39 +0100 Subject: [PATCH] feat: Size arg --- main.py | 5 ++++- src/args.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 23a9850..7ce5560 100644 --- a/main.py +++ b/main.py @@ -12,11 +12,14 @@ def main(): match args.mode: case 'train': - size = None + size = args.size if args.method == 'optuna': size = 2 ** 12 + print(f"Using size {size} for optuna (was {args.size})") if args.debug: size = 2 ** 10 + print(f"Using size {size} for debug (was {args.size})") + train( device=device, dataset=args.dataset, diff --git a/src/args.py b/src/args.py index 5511906..3fc8325 100644 --- a/src/args.py +++ b/src/args.py @@ -35,6 +35,8 @@ def parse_arguments(): train_parser.add_argument("--method", choices=["fetch", "optuna", "full"], required=True, help="Method to use for training") + train_parser.add_argument("--size", "-s", type=int, required=False, + help="Size of the subset of the dataset to use") compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])