feat: Context CLI arg
This commit is contained in:
parent
cd74949b74
commit
a4583d402b
8 changed files with 38 additions and 31 deletions
|
|
@ -14,12 +14,13 @@ def train(
|
|||
data_root: str,
|
||||
n_trials: int | None = None,
|
||||
size: int | None = None,
|
||||
context_length: int | None = None,
|
||||
method: str = 'optuna',
|
||||
model_name: str | None = None,
|
||||
model_path: str | None = None,
|
||||
model_out: str | None = None
|
||||
):
|
||||
batch_size = 2
|
||||
batch_size = 64
|
||||
|
||||
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
|
||||
|
||||
|
|
@ -38,6 +39,9 @@ def train(
|
|||
if size:
|
||||
dataset_common_args['size'] = size
|
||||
|
||||
if context_length:
|
||||
dataset_common_args['context_length'] = context_length
|
||||
|
||||
print("Loading in the dataset...")
|
||||
if dataset in dataset_called:
|
||||
training_set = dataset_called[dataset](split='train', **dataset_common_args)
|
||||
|
|
|
|||
Reference in a new issue