feat: Add model choice
parent bb241154d9
commit ef50d6321e
10 changed files with 102 additions and 54 deletions
@@ -3,60 +3,73 @@ from typing import Callable
 import optuna
 import optuna.trial as tr
 import torch
-from torch import nn as nn
+from torch import nn
 from torch.utils.data import DataLoader

-from .trainer import Trainer
-from ..models.cnn import CNNPredictor
+from .train import train
+from .trainer import Trainer
+from ..models import Model, CNNPredictor, Transformer


-def create_model(trial: tr.Trial, vocab_size: int = 256):
-    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
-    embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
-
-    return CNNPredictor(
-        vocab_size=vocab_size,
-        hidden_dim=hidden_dim,
-        embed_dim=embedding_dim,
-    )
+def create_model(trial: tr.Trial, model: nn.Module):
+    match model.__class__:
+        case CNNPredictor.__class__:
+            return model(
+                hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
+                embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
+                vocab_size=256,
+            )
+        case Transformer.__class__:
+            nhead = trial.suggest_int("nhead", 2, 8, log=True)
+            d_model = trial.suggest_int("d_model", 64, 512, step=nhead)
+            return model(
+                d_model=d_model,
+                nhead=nhead,
+                num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
+                num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
+                dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
+                dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
+                activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
+                layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
+            )
+    return None


 def objective_function(
     trial: tr.Trial,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    model: Model,
     device: str
 ):
-    model = create_model(trial).to(device)
-    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
+    model = create_model(trial, model).to(device)
+    _, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
     return min(validation_loss)


 class OptunaTrainer(Trainer):
     def __init__(self, n_trials: int | None = None):
         super().__init__()
-        self.n_trials = n_trials if n_trials is not None else 20
+        self.n_trials = n_trials if n_trials else 20
         print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

     def execute(
         self,
-        model: nn.Module | None,
+        model: Model,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
         n_epochs: int,
         device: str
-    ) -> None:
-        study = optuna.create_study(study_name="CNN network", direction="minimize")
+    ) -> nn.Module:
+        study = optuna.create_study(direction="minimize")
         study.optimize(
-            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
+            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
             n_trials=self.n_trials
         )

         best_params = study.best_trial.params
-        best_model = CNNPredictor(
+        best_model = model(
             **best_params
         )
         torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")

         return best_model
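
A note on the dispatch in the new create_model: case CNNPredictor.__class__: and case Transformer.__class__: are value patterns, so the subject model.__class__ is compared against the metaclass (plain type for both classes, unless a custom metaclass is involved), not against CNNPredictor or Transformer themselves. With an instance neither case matches, and with a class object every class falls into the first case. Below is a minimal sketch of dispatching on the class object instead, assuming the model is passed in as a class (which best_model = model(**best_params) in execute suggests); suggest_model is an illustrative name, not part of the commit:

import optuna.trial as tr

from ..models import CNNPredictor, Transformer


def suggest_model(trial: tr.Trial, model_cls: type):
    # Compare against the class objects themselves; "is" avoids accidental
    # matches through __eq__ overloads.
    if model_cls is CNNPredictor:
        return model_cls(
            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
            vocab_size=256,
        )
    if model_cls is Transformer:
        nhead = trial.suggest_int("nhead", 2, 8, log=True)
        return model_cls(
            d_model=trial.suggest_int("d_model", 64, 512, step=nhead),
            nhead=nhead,
        )
    return None

If instances are passed instead, class patterns (case CNNPredictor(): / case Transformer():) are the idiomatic match-statement equivalent.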
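
In the Transformer branch, d_model is drawn with step=nhead from a range starting at 64, so the sampled value is 64 plus a multiple of nhead; that is itself a multiple of nhead only when nhead divides 64 (2, 4 or 8), and because the step changes with nhead, the distribution recorded for the "d_model" parameter differs from trial to trial. If the Transformer wraps torch.nn.Transformer, d_model must be divisible by nhead, so a common alternative is to sample a per-head width and derive d_model from it. A sketch under that assumption; the parameter name d_model_per_head is made up for the example:

import optuna.trial as tr


def suggest_transformer_dims(trial: tr.Trial) -> tuple[int, int]:
    # nhead as in the commit; d_model derived so it is always nhead * k and
    # the search space for each parameter name stays fixed across trials.
    nhead = trial.suggest_int("nhead", 2, 8, log=True)
    d_model = nhead * trial.suggest_int("d_model_per_head", 16, 64, log=True)
    return d_model, nhead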
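
execute rebuilds the winner with best_model = model(**best_params), but study.best_trial.params holds only the values that were suggested, so fixed constructor arguments such as the CNN's vocab_size=256 are not part of it. One way to replay the full configuration, assuming create_model remains the factory used during the search, is to feed the best parameters back through it with optuna.trial.FixedTrial, which answers each suggest_* call from the stored dictionary. A sketch, reusing study, model and create_model from the code above:

import optuna

# FixedTrial returns the recorded value for every suggested parameter, so the
# factory applies the same fixed kwargs (e.g. vocab_size=256) as in the search.
best_model = create_model(optuna.trial.FixedTrial(study.best_trial.params), model)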