feat: Size arg
This commit is contained in:
parent
6c5908e6ae
commit
67310ce4f4
2 changed files with 6 additions and 1 deletions
5
main.py
5
main.py
|
|
@ -12,11 +12,14 @@ def main():
|
||||||
|
|
||||||
match args.mode:
|
match args.mode:
|
||||||
case 'train':
|
case 'train':
|
||||||
size = None
|
size = args.size
|
||||||
if args.method == 'optuna':
|
if args.method == 'optuna':
|
||||||
size = 2 ** 12
|
size = 2 ** 12
|
||||||
|
print(f"Using size {size} for optuna (was {args.size})")
|
||||||
if args.debug:
|
if args.debug:
|
||||||
size = 2 ** 10
|
size = 2 ** 10
|
||||||
|
print(f"Using size {size} for debug (was {args.size})")
|
||||||
|
|
||||||
train(
|
train(
|
||||||
device=device,
|
device=device,
|
||||||
dataset=args.dataset,
|
dataset=args.dataset,
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,8 @@ def parse_arguments():
|
||||||
train_parser.add_argument("--method",
|
train_parser.add_argument("--method",
|
||||||
choices=["fetch", "optuna", "full"], required=True,
|
choices=["fetch", "optuna", "full"], required=True,
|
||||||
help="Method to use for training")
|
help="Method to use for training")
|
||||||
|
train_parser.add_argument("--size", "-s", type=int, required=False,
|
||||||
|
help="Size of the subset of the dataset to use")
|
||||||
|
|
||||||
compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
|
compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
|
||||||
|
|
||||||
|
|
|
||||||
Reference in a new issue