Streamline datasets
This commit is contained in:
parent
849bcd7b77
commit
befb1a96a5
8 changed files with 222 additions and 64 deletions
|
|
@ -35,6 +35,11 @@ def objective_function(
|
|||
|
||||
|
||||
class OptunaTrainer(Trainer):
|
||||
def __init__(self, n_trials: int | None = None):
|
||||
super().__init__()
|
||||
self.n_trials = n_trials if n_trials is not None else 20
|
||||
print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
|
||||
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
|
|
@ -47,7 +52,7 @@ class OptunaTrainer(Trainer):
|
|||
study = optuna.create_study(study_name="CNN network", direction="minimize")
|
||||
study.optimize(
|
||||
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
|
||||
n_trials=20
|
||||
n_trials=self.n_trials
|
||||
)
|
||||
|
||||
best_params = study.best_trial.params
|
||||
|
|
|
|||
Reference in a new issue