diff --git a/src/trainers/FullTrainer.py b/src/trainers/FullTrainer.py
index cb0dc2c..d3e8bde 100644
--- a/src/trainers/FullTrainer.py
+++ b/src/trainers/FullTrainer.py
@@ -20,7 +20,7 @@ class FullTrainer(Trainer):
             raise ValueError("Model must be provided: run optuna optimizations first")
 
         model.to(device)
-        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs)
+        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
         print_losses(train_loss, val_loss)
 
         return model
diff --git a/src/trainers/OptunaTrainer.py b/src/trainers/OptunaTrainer.py
index de273e9..848bd19 100644
--- a/src/trainers/OptunaTrainer.py
+++ b/src/trainers/OptunaTrainer.py
@@ -39,7 +39,7 @@ def objective_function(
     device: str
 ):
     model = create_model(trial, model).to(device)
 
-    _, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
+    _, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
 
     return min(validation_loss)
diff --git a/src/trainers/train.py b/src/trainers/train.py
index be4aa34..3b7ed59 100644
--- a/src/trainers/train.py
+++ b/src/trainers/train.py
@@ -15,7 +15,6 @@ def train(
     weight_decay: float = 1e-8,
     device="cuda"
 ) -> tuple[list[float], list[float]]:
-    model.to(device)
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
 
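
Note: the patch moves device placement out of train() and makes each caller pass device= explicitly (both call sites already call model.to(device) themselves, so the removed line was redundant). A minimal sketch of the resulting calling convention follows; TinyModel, the tensors, and the loaders are hypothetical stand-ins for the repo's real objects, and only train()'s signature is taken from the diff, so the final call is left commented out.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical model exposing .loss_function, as the trainers assume.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)
        self.loss_function = nn.MSELoss()

    def forward(self, x):
        return self.net(x)

xs, ys = torch.randn(32, 4), torch.randn(32, 1)
train_loader = DataLoader(TensorDataset(xs, ys), batch_size=8)
validation_loader = DataLoader(TensorDataset(xs, ys), batch_size=8)

# After this patch, the caller owns device placement ...
model = TinyModel().to(device)
# ... and forwards the same device to train(), which no longer calls model.to(device):
# train_loss, val_loss = train(model, train_loader, validation_loader,
#                              model.loss_function, device=device)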