fix: Properly pass device

This commit is contained in:
Tibo De Peuter 2025-12-09 14:50:25 +01:00
parent 28ae8191ad
commit 8311eabd4d
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
3 changed files with 2 additions and 3 deletions

View file

@ -20,7 +20,7 @@ class FullTrainer(Trainer):
raise ValueError("Model must be provided: run optuna optimizations first")
model.to(device)
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs)
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
print_losses(train_loss, val_loss)
return model

View file

@ -39,7 +39,7 @@ def objective_function(
device: str
):
model = create_model(trial, model).to(device)
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
return min(validation_loss)

View file

@ -15,7 +15,6 @@ def train(
weight_decay: float = 1e-8,
device="cuda"
) -> tuple[list[float], list[float]]:
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)