fix: Properly pass device
This commit is contained in:
parent
28ae8191ad
commit
8311eabd4d
3 changed files with 2 additions and 3 deletions
|
|
@ -39,7 +39,7 @@ def objective_function(
|
|||
device: str
|
||||
):
|
||||
model = create_model(trial, model).to(device)
|
||||
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
|
||||
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
|
||||
return min(validation_loss)
|
||||
|
||||
|
||||
|
|
|
|||
Reference in a new issue