72 lines
2.6 KiB
Python
72 lines
2.6 KiB
Python
import optuna
|
|
import optuna.trial as tr
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader
|
|
|
|
from .train import train
|
|
from .trainer import Trainer
|
|
from ..models import Model, CNNPredictor, ByteTransformer
|
|
|
|
|
|
def create_model(trial: tr.Trial, model: nn.Module):
    """Instantiate a supported architecture with hyperparameters drawn from ``trial``.

    Args:
        trial: Optuna trial whose ``suggest_*`` methods supply the
            hyperparameter values.
        model: The model class to instantiate (``CNNPredictor``,
            ``ByteTransformer`` or a subclass of either). An instance is
            also accepted, in which case its class is used.

    Returns:
        A newly constructed model, or ``None`` when the architecture is not
        one of the supported ones.
    """
    # Dispatch on the class itself. Matching ``model.__class__`` against
    # ``CNNPredictor.__class__`` compares metaclasses, so for classes that
    # share a metaclass the first branch would win regardless of architecture.
    model_cls = model if isinstance(model, type) else type(model)

    if issubclass(model_cls, CNNPredictor):
        return model_cls(
            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
            vocab_size=256,
        )

    if issubclass(model_cls, ByteTransformer):
        # Only powers of 2; each choice divides the fixed d_model of 128.
        nhead = trial.suggest_categorical("nhead", [2, 4, 8])
        # TODO: once d_model is searchable, sample it as a multiple of nhead,
        # e.g. nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead).
        return model_cls(
            d_model=128,  # hard coded for now as data loaders provide fixed (B, 128) tensors
            nhead=nhead,
            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
        )

    # Unknown architecture: callers receive None, as before.
    return None
|
|
|
|
|
|
def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    model: Model,
    device: str
):
    """Optuna objective: build a candidate model, train it, and score it.

    Args:
        trial: Optuna trial used by ``create_model`` to sample hyperparameters.
        training_loader: Loader forwarded to ``train`` as the training data.
        validation_loader: Loader forwarded to ``train`` as the validation data.
        model: Architecture selector handed to ``create_model``.
        device: Device the candidate is moved to and that ``train`` runs on.

    Returns:
        The smallest entry of the validation losses returned by ``train``.
    """
    candidate = create_model(trial, model)
    candidate = candidate.to(device)

    # ``train`` returns a pair; only its second element (the validation
    # losses) is used for scoring.
    _unused, validation_losses = train(
        candidate,
        training_loader,
        validation_loader,
        candidate.loss_function,
        device=device,
    )
    return min(validation_losses)
|
|
|
|
|
|
class OptunaTrainer(Trainer):
    """Trainer that picks model hyperparameters with an Optuna study."""

    def __init__(self, n_trials: int | None = None):
        """Create the trainer.

        Args:
            n_trials: Number of Optuna trials to run. ``None`` (or any falsy
                value such as ``0``) falls back to 20.
        """
        super().__init__()
        self.n_trials = n_trials or 20
        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

    def execute(
        self,
        model: Model,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        n_epochs: int,
        device: str
    ) -> nn.Module:
        """Run the hyperparameter search and build a model from the best trial.

        Args:
            model: Architecture selector passed to ``objective_function`` and
                ``create_model``.
            train_loader: Training data used by every trial.
            validation_loader: Validation data used to score every trial.
            n_epochs: Currently unused by this trainer.
            device: Device on which each trial's model is trained.

        Returns:
            A freshly constructed model configured with the best trial's
            hyperparameters. It is a new instance, not the one trained during
            the search.
        """
        study = optuna.create_study(direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
            n_trials=self.n_trials,
        )

        best_params = study.best_trial.params

        # Replay the best parameters through create_model instead of calling
        # model(**best_params): the study only records the *suggested* values,
        # so fixed constructor arguments (vocab_size, d_model) would otherwise
        # be missing from the rebuilt model.
        best_model = create_model(tr.FixedTrial(best_params), model)

        return best_model
|