import optuna
import optuna.trial as tr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utils import make_context_pairs
from optuna_trial import create_model

# hyper parameters
context_length = 128


def train_and_eval(
    model: nn.Module,
    training_data: bytes,
    validation_data: bytes,
    batch_size: int,
    epochs: int = 100,
    learning_rate: float = 1e-3,
    device: torch.device = torch.device("cpu"),
) -> dict:
    """Train ``model`` with AdamW and evaluate it after every epoch.

    Both byte strings are turned into datasets with ``make_context_pairs``
    (using the module-level ``context_length``) and wrapped in ``DataLoader``
    objects that yield ``(x, y)`` batches. The loss is
    ``F.cross_entropy(model(x), y)`` for training and validation alike.

    Args:
        model: Network to optimise; it is moved to ``device`` in place.
        training_data: Raw bytes used to build the training pairs.
        validation_data: Raw bytes used to build the validation pairs.
        batch_size: Batch size for both loaders.
        epochs: Number of passes over the training loader.
        learning_rate: Learning rate handed to ``torch.optim.AdamW``.
        device: Device on which the model and every batch are placed.

    Returns:
        A dict with the keys ``"training_losses"`` and
        ``"validation_losses"`` (one mean batch loss per epoch) and
        ``"best_validation_loss"`` (the smallest validation mean seen, or
        ``inf`` when ``epochs`` is 0).

    Raises:
        ZeroDivisionError: If ``epochs`` is positive and a loader reports
            zero batches, because the per-epoch mean divides by the loader
            length.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    training_loader = DataLoader(
        make_context_pairs(training_data, context_length=context_length),
        batch_size=batch_size,
    )
    validation_loader = DataLoader(
        make_context_pairs(validation_data, context_length=context_length),
        batch_size=batch_size,
    )

    training_losses = []
    validation_losses = []
    best_val_loss = float("inf")

    for epoch in range(epochs):
        # Training pass: one optimizer step per batch.
        model.train()
        train_loss = 0
        for x, y in tqdm(training_loader, desc=f"Epoch {epoch}"):
            x, y = x.to(device), y.to(device)
            prediction = model(x)
            loss = F.cross_entropy(prediction, y)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        training_losses.append(train_loss / len(training_loader))

        # Validation pass: no gradients are recorded.
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for x, y in validation_loader:
                x, y = x.to(device), y.to(device)
                prediction = model(x)
                loss = F.cross_entropy(prediction, y)
                val_loss += loss.item()
        validation_losses.append(val_loss / len(validation_loader))

        if validation_losses[-1] < best_val_loss:
            best_val_loss = validation_losses[-1]

    return {
        "training_losses": training_losses,
        "validation_losses": validation_losses,
        "best_validation_loss": best_val_loss,
    }


def objective_function(
    trial: tr.Trial,
    train_data: bytes,
    validation_data: bytes,
    batch_size: int,
    device: torch.device = torch.device("cpu"),
):
    """Optuna objective: build a model for ``trial`` and score it.

    Args:
        trial: Optuna trial passed to ``create_model``.
        train_data: Raw bytes forwarded to ``train_and_eval`` for training.
        validation_data: Raw bytes forwarded for validation.
        batch_size: Batch size forwarded to ``train_and_eval``.
        device: Device to train on. Defaults to the CPU, so existing
            four-argument callers behave as before.

    Returns:
        The ``"best_validation_loss"`` entry of the ``train_and_eval`` result.
    """
    model = create_model(trial)
    # Forward the device explicitly; otherwise train_and_eval silently falls
    # back to its CPU default even when the caller selected a GPU.
    result = train_and_eval(
        model, train_data, validation_data, batch_size, device=device
    )
    return result["best_validation_loss"]


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): these look like placeholders (empty data, batch size 0);
    # presumably they must be replaced with real values before the study can
    # run - confirm with the author.
    train_data = b""
    validation_data = b""
    batch_size = 0

    study = optuna.create_study(study_name="CNN network", direction="minimize")
    study.optimize(
        lambda trial: objective_function(
            trial, train_data, validation_data, batch_size, device=device
        ),
        n_trials=10,
    )