import torch
import torch.nn as nn
# NOTE(review): `torch.functional` is a different module from
# `torch.nn.functional`, which is what the alias `F` conventionally refers to.
# `F` is unused in the visible code, so the import is left as-is; confirm
# which module was intended before using it.
import torch.functional as F
import optuna.trial as tr
from torch.utils.data import DataLoader

from optuna_trial import create_model
from data_utils import make_context_pairs

# hyper parameters
context_length = 128


def train_and_eval(
    model: nn.Module,
    training_data: bytes,
    validation_data: bytes,
    batch_size: int,
    epochs: int = 100,
    learning_rate: float = 1e-3,
    device: torch.device = torch.device("cpu"),
):
    """Set up training of ``model`` and run the (still skeletal) epoch loop.

    The function moves the model to ``device``, creates an AdamW optimizer
    over its parameters, and wraps the outputs of ``make_context_pairs`` for
    the training and validation data in ``DataLoader`` objects. The epoch
    loop currently only switches the model to training mode; no forward or
    backward pass, optimizer step, or validation is implemented yet, so the
    optimizer and both loaders are not consumed.

    Args:
        model: Module to train; moved to ``device`` in place.
        training_data: Raw training data handed to ``make_context_pairs``.
        validation_data: Raw validation data handed to ``make_context_pairs``.
        batch_size: Batch size used by both data loaders.
        epochs: Number of passes of the epoch loop.
        learning_rate: Learning rate for the AdamW optimizer.
        device: Device the model is moved to.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Forward batch_size explicitly: without it DataLoader falls back to its
    # default of one sample per batch and the argument would have no effect.
    # NOTE(review): assumes make_context_pairs returns something DataLoader
    # accepts as a dataset -- confirm against data_utils.
    training_loader = DataLoader(
        make_context_pairs(training_data, context_length=context_length),
        batch_size=batch_size,
    )
    validation_loader = DataLoader(
        make_context_pairs(validation_data, context_length=context_length),
        batch_size=batch_size,
    )

    for epoch in range(epochs):
        model.train()
        # TODO: iterate training_loader, compute the loss, and step the
        # optimizer; then evaluate on validation_loader.


def objective_function(trial: tr.Trial):
    """Build a model for ``trial`` via ``create_model``.

    The model is currently discarded and nothing is returned.
    """
    model = create_model(trial)
    # TODO: train/evaluate the model and return the score; an Optuna
    # objective is expected to return the value to optimize.


if __name__ == "__main__":
    # Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")