feat: Context CLI arg

This commit is contained in:
Tibo De Peuter 2025-12-11 13:58:38 +01:00
parent cd74949b74
commit a4583d402b
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
8 changed files with 38 additions and 31 deletions

View file

@ -14,12 +14,13 @@ def train(
data_root: str,
n_trials: int | None = None,
size: int | None = None,
context_length: int | None = None,
method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None,
model_out: str | None = None
):
batch_size = 2
batch_size = 64
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
@ -38,6 +39,9 @@ def train(
if size:
dataset_common_args['size'] = size
if context_length:
dataset_common_args['context_length'] = context_length
print("Loading in the dataset...")
if dataset in dataset_called:
training_set = dataset_called[dataset](split='train', **dataset_common_args)