fix: Properly pass device

This commit is contained in:
Tibo De Peuter 2025-12-09 14:50:25 +01:00
parent 28ae8191ad
commit 8311eabd4d
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
3 changed files with 2 additions and 3 deletions

View file

@ -39,7 +39,7 @@ def objective_function(
device: str
):
model = create_model(trial, model).to(device)
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
return min(validation_loss)