feat: Context CLI arg

This commit is contained in:
Tibo De Peuter 2025-12-11 13:58:38 +01:00
parent cd74949b74
commit a4583d402b
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
8 changed files with 38 additions and 31 deletions

13
main.py
View file

@ -14,11 +14,13 @@ def main():
case 'train':
size = int(args.size) if args.size else None
if args.method == 'optuna':
size = 2 ** 12
print(f"Using size {size} for optuna (was {args.size})")
size = min(size, 2 ** 12) if size else 2 ** 12
if size != args.size:
print(f"Using size {size} for optuna (was {args.size})")
if args.debug:
size = 2 ** 10
print(f"Using size {size} for debug (was {args.size})")
size = min(size, 2 ** 10) if size else 2 ** 10
if size != args.size:
print(f"Using size {size} for debug (was {args.size})")
train(
device=device,
@ -29,7 +31,8 @@ def main():
method=args.method,
model_name=args.model,
model_path=args.model_load_path,
model_out=args.model_save_path
model_out=args.model_save_path,
context_length=args.context
)
case 'compress':