19 lines
732 B
Python
19 lines
732 B
Python
import optuna.trial as tr
|
|
from cnn import CNNPredictor
|
|
|
|
def create_model(trial: tr.Trial, vocab_size: int = 256, context_length: int = 128):
    """Construct a ``CNNPredictor`` whose architecture is sampled from an Optuna trial.

    The following hyperparameters are drawn from ``trial``:

    * ``num_layers``    -- integer in [1, 6]
    * ``hidden_dim``    -- integer in [64, 512], sampled on a log scale
    * ``kernel_size``   -- integer in [2, 7]
    * ``dropout_prob``  -- float in [0.1, 0.5]
    * ``use_batchnorm`` -- one of ``True`` / ``False``

    Args:
        trial: The Optuna trial used to suggest hyperparameter values.
        vocab_size: Forwarded unchanged to ``CNNPredictor``.
        context_length: Forwarded unchanged to ``CNNPredictor``.

    Returns:
        The newly constructed ``CNNPredictor`` instance.
    """
    # A dict display evaluates left to right. The suggest_* calls therefore
    # happen in the same fixed order on every trial. The keys double as the
    # CNNPredictor keyword names.
    sampled_arch = {
        "num_layers": trial.suggest_int("num_layers", 1, 6),
        "hidden_dim": trial.suggest_int("hidden_dim", 64, 512, log=True),
        "kernel_size": trial.suggest_int("kernel_size", 2, 7),
        "dropout_prob": trial.suggest_float("dropout_prob", 0.1, 0.5),
        "use_batchnorm": trial.suggest_categorical("use_batchnorm", [True, False]),
    }

    return CNNPredictor(
        vocab_size=vocab_size,
        context_length=context_length,
        **sampled_arch,
    )
|