fix: Properly pass device
parent 28ae8191ad
commit 8311eabd4d
3 changed files with 2 additions and 3 deletions
@@ -20,7 +20,7 @@ class FullTrainer(Trainer):
             raise ValueError("Model must be provided: run optuna optimizations first")

         model.to(device)
-        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs)
+        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
         print_losses(train_loss, val_loss)

         return model
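For context, a minimal caller-side sketch of the fixed flow, assuming the device is chosen from CUDA availability; the helper name run_full_training is hypothetical, while train, print_losses and model.loss_function come from the diff above:

import torch

def run_full_training(model, train_loader, validation_loader, n_epochs):
    # Choose the device once and forward it explicitly instead of relying on
    # train()'s default.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    train_loss, val_loss = train(
        model, train_loader, validation_loader, model.loss_function, n_epochs,
        device=device,  # the fix: the caller's device now reaches train()
    )
    print_losses(train_loss, val_loss)
    return model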
@@ -39,7 +39,7 @@ def objective_function(
     device: str
 ):
     model = create_model(trial, model).to(device)
-    _, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
+    _, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
     return min(validation_loss)
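A hedged sketch of how this objective could be wired into an Optuna study so that the caller's device reaches train(); the argument names other than device, and the study settings, are assumptions not shown in the diff:

import functools

import optuna
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Bind everything except the trial, which Optuna supplies on each call.
objective = functools.partial(
    objective_function,
    model=base_model,                     # assumed starting model
    training_loader=training_loader,      # assumed DataLoader
    validation_loader=validation_loader,  # assumed DataLoader
    device=device,
)

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)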
@@ -15,7 +15,6 @@ def train(
     weight_decay: float = 1e-8,
     device="cuda"
 ) -> tuple[list[float], list[float]]:

     model.to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
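The diff shows only train()'s signature and optimizer setup; below is a minimal sketch of how the device parameter is typically consumed inside such a loop. The loop body, the defaults for n_epochs and learning_rate, and the per-epoch loss averaging are assumptions; only the parameter names visible above are taken from the diff:

import torch

def train(model, training_loader, validation_loader, loss_function,
          n_epochs=10, learning_rate=1e-3,
          weight_decay: float = 1e-8,
          device="cuda") -> tuple[list[float], list[float]]:
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    train_losses, val_losses = [], []
    for _ in range(n_epochs):
        model.train()
        total = 0.0
        for inputs, targets in training_loader:
            # Batches must live on the same device as the model.
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = loss_function(model(inputs), targets)
            loss.backward()
            optimizer.step()
            total += loss.item()
        train_losses.append(total / len(training_loader))

        model.eval()
        total = 0.0
        with torch.no_grad():
            for inputs, targets in validation_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                total += loss_function(model(inputs), targets).item()
        val_losses.append(total / len(validation_loader))

    return train_losses, val_losses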