feat: Context CLI arg
This commit is contained in:
parent
cd74949b74
commit
a4583d402b
8 changed files with 38 additions and 31 deletions
13
main.py
13
main.py
|
|
@ -14,11 +14,13 @@ def main():
|
|||
case 'train':
|
||||
size = int(args.size) if args.size else None
|
||||
if args.method == 'optuna':
|
||||
size = 2 ** 12
|
||||
print(f"Using size {size} for optuna (was {args.size})")
|
||||
size = min(size, 2 ** 12) if size else 2 ** 12
|
||||
if size != args.size:
|
||||
print(f"Using size {size} for optuna (was {args.size})")
|
||||
if args.debug:
|
||||
size = 2 ** 10
|
||||
print(f"Using size {size} for debug (was {args.size})")
|
||||
size = min(size, 2 ** 10) if size else 2 ** 10
|
||||
if size != args.size:
|
||||
print(f"Using size {size} for debug (was {args.size})")
|
||||
|
||||
train(
|
||||
device=device,
|
||||
|
|
@ -29,7 +31,8 @@ def main():
|
|||
method=args.method,
|
||||
model_name=args.model,
|
||||
model_path=args.model_load_path,
|
||||
model_out=args.model_save_path
|
||||
model_out=args.model_save_path,
|
||||
context_length=args.context
|
||||
)
|
||||
|
||||
case 'compress':
|
||||
|
|
|
|||
Reference in a new issue