fix: Subset Optuna

This commit is contained in:
Tibo De Peuter 2025-12-10 11:09:44 +01:00
parent 51e0ed7fc0
commit 956ff79fa1
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2

View file

@ -12,12 +12,17 @@ def main():
match args.mode:
case 'train':
size = None
if args.method == 'optuna':
size = 2 ** 12
if args.debug:
size = 2 ** 10
train(
device=device,
dataset=args.dataset,
data_root=args.data_root,
n_trials=3 if args.debug else None,
size=2 ** 10 if args.debug else None,
size=size,
method=args.method,
model_name=args.model,
model_path=args.model_load_path,