feat: Results dir

This commit is contained in:
Tibo De Peuter 2025-12-11 15:38:43 +01:00
parent a4583d402b
commit 97e84a97db
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
4 changed files with 12 additions and 6 deletions

View file

@ -18,7 +18,8 @@ def train(
method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None,
model_out: str | None = None
model_out: str | None = None,
results_dir: str = 'results'
):
batch_size = 64
@ -58,7 +59,7 @@ def train(
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)
print("Training")
best_model = trainer.execute(