import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

import optuna
import optuna.trial as tr

from optuna_trial import create_model
from utils import load_data, make_context_pairs

# Hyperparameters
context_length = 128


def train_and_eval(
    model: nn.Module,
    training_data: bytes,
    validation_data: bytes,
    batch_size: int,
    epochs: int = 100,
    learning_rate: float = 1e-3,
    device: torch.device = torch.device("cpu"),
) -> dict:
    """Train the model and track per-epoch training/validation losses."""
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    training_loader = DataLoader(
        make_context_pairs(training_data, context_length=context_length),
        batch_size=batch_size,
    )
    validation_loader = DataLoader(
        make_context_pairs(validation_data, context_length=context_length),
        batch_size=batch_size,
    )

    training_losses = []
    validation_losses = []
    best_val_loss = float("inf")

    for epoch in range(epochs):
        # Training pass.
        model.train()
        train_loss = 0.0
        for x, y in tqdm(training_loader, desc=f"Epoch {epoch}"):
            x, y = x.to(device), y.to(device)
            prediction = model(x)
            loss = F.cross_entropy(prediction, y)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        training_losses.append(train_loss / len(training_loader))

        # Validation pass (no gradients needed).
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for x, y in validation_loader:
                x, y = x.to(device), y.to(device)
                prediction = model(x)
                loss = F.cross_entropy(prediction, y)
                val_loss += loss.item()
            validation_losses.append(val_loss / len(validation_loader))
            best_val_loss = min(best_val_loss, validation_losses[-1])

    return {
        "training_losses": training_losses,
        "validation_losses": validation_losses,
        "best_validation_loss": best_val_loss,
    }


def objective_function(
    trial: tr.Trial,
    train_data: bytes,
    validation_data: bytes,
    batch_size: int,
    device: torch.device,
) -> float:
    """Optuna objective: build a model from the trial and return its best validation loss."""
    model = create_model(trial)
    result = train_and_eval(model, train_data, validation_data, batch_size, device=device)
    return result["best_validation_loss"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-data", type=str, required=True)
    parser.add_argument("--validation-data", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=128)
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_data = load_data(args.train_data)
    validation_data = load_data(args.validation_data)
    batch_size = args.batch_size

    print(f"training data length: {len(train_data)}")
    print(f"validation data length: {len(validation_data)}")
    print(f"batch size: {batch_size}")

    # Minimize the best validation loss over 10 Optuna trials. Note: device is
    # now passed through explicitly; the original silently trained on CPU.
    study = optuna.create_study(study_name="CNN network", direction="minimize")
    study.optimize(
        lambda trial: objective_function(trial, train_data, validation_data, batch_size, device),
        n_trials=10,
    )
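

# ---------------------------------------------------------------------------
# For reference only: a minimal sketch of the kind of define-by-run search
# space that `create_model` (imported from optuna_trial above) is expected to
# build. This is an illustrative assumption, not the actual implementation:
# the hyperparameter names, ranges, and layer shapes below are hypothetical.
# It maps a (batch, 1, context_length) float tensor to 256 logits, one per
# possible next byte, matching the cross_entropy loss in train_and_eval.
def example_create_model(trial: tr.Trial) -> nn.Module:
    # Sample the CNN's hyperparameters from the trial; Optuna records each
    # suggest_* call as a dimension of the search space it optimizes.
    n_filters = trial.suggest_int("n_filters", 32, 256, log=True)
    n_layers = trial.suggest_int("n_layers", 1, 4)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)

    layers = []
    in_channels = 1
    for _ in range(n_layers):
        layers += [
            nn.Conv1d(in_channels, n_filters, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        in_channels = n_filters
    # Pool over the sequence dimension, then classify the next byte.
    layers += [nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(n_filters, 256)]
    return nn.Sequential(*layers)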