fix: Properly pass device
This commit is contained in:
parent
28ae8191ad
commit
8311eabd4d
3 changed files with 2 additions and 3 deletions
|
|
@ -20,7 +20,7 @@ class FullTrainer(Trainer):
|
||||||
raise ValueError("Model must be provided: run optuna optimizations first")
|
raise ValueError("Model must be provided: run optuna optimizations first")
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs)
|
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
|
||||||
print_losses(train_loss, val_loss)
|
print_losses(train_loss, val_loss)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ def objective_function(
|
||||||
device: str
|
device: str
|
||||||
):
|
):
|
||||||
model = create_model(trial, model).to(device)
|
model = create_model(trial, model).to(device)
|
||||||
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
|
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
|
||||||
return min(validation_loss)
|
return min(validation_loss)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ def train(
|
||||||
weight_decay: float = 1e-8,
|
weight_decay: float = 1e-8,
|
||||||
device="cuda"
|
device="cuda"
|
||||||
) -> tuple[list[float], list[float]]:
|
) -> tuple[list[float], list[float]]:
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
|
||||||
|
|
|
||||||
Reference in a new issue