35 lines
1 KiB
Python
35 lines
1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.functional as F
|
|
import optuna.trial as tr
|
|
from torch.utils.data import DataLoader
|
|
|
|
from optuna_trial import create_model
|
|
from data_utils import make_context_pairs
|
|
|
|
# Hyperparameters
# Context window length handed to make_context_pairs for both the training and
# validation splits (presumably measured in bytes, since the inputs are raw
# ``bytes`` — confirm against data_utils).
context_length: int = 128
|
|
|
|
def train_and_eval(
    model: nn.Module,
    training_data: bytes,
    validation_data: bytes,
    batch_size: int,
    epochs: int = 100,
    learning_rate: float = 1e-3,
    device: torch.device = torch.device("cpu"),
) -> float:
    """Train ``model`` with AdamW and return its mean validation loss.

    Both byte strings are turned into examples with ``make_context_pairs``
    using the module-level ``context_length``.

    NOTE(review): assumes ``make_context_pairs`` yields ``(inputs, targets)``
    pairs whose targets are integer class indices, and that
    ``model(inputs)`` returns logits whose last dimension is the number of
    classes — confirm against data_utils / optuna_trial.

    Args:
        model: Network to optimise in place; it is moved to ``device``.
        training_data: Raw bytes used for training.
        validation_data: Raw bytes used only for the final evaluation.
        batch_size: Number of examples per training and evaluation batch.
        epochs: Number of full passes over the training data.
        learning_rate: AdamW learning rate.
        device: Device the model and every batch are moved to.

    Returns:
        Cross-entropy averaged over every validation target after training,
        or ``nan`` if the validation data yields no batches.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # batch_size must be passed explicitly: DataLoader otherwise defaults to 1.
    # NOTE(review): consider shuffle=True for training if make_context_pairs
    # returns a map-style (sized, indexable) dataset.
    training_loader = DataLoader(
        make_context_pairs(training_data, context_length=context_length),
        batch_size=batch_size,
    )
    validation_loader = DataLoader(
        make_context_pairs(validation_data, context_length=context_length),
        batch_size=batch_size,
    )

    for _epoch in range(epochs):
        model.train()
        for inputs, targets in training_loader:
            loss = _batch_loss(model, inputs, targets, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return _mean_loss(model, validation_loader, device)


def _batch_loss(
    model: nn.Module,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Return the mean cross-entropy of ``model`` on one (inputs, targets) batch."""
    inputs = inputs.to(device)
    # cross_entropy requires int64 class indices.
    targets = targets.to(device).long()
    logits = model(inputs)
    # Flattening handles both one target per example (logits of shape (B, C))
    # and one target per position (logits of shape (B, T, C)); which of the
    # two the data produces is an assumption to confirm.
    return nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
    )


@torch.no_grad()
def _mean_loss(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> float:
    """Return cross-entropy averaged over every target in ``loader`` (nan if empty)."""
    model.eval()
    total_loss = 0.0
    total_targets = 0
    for inputs, targets in loader:
        count = targets.numel()
        # Weight by target count so a short final batch is not over-weighted.
        total_loss += _batch_loss(model, inputs, targets, device).item() * count
        total_targets += count
    return total_loss / total_targets if total_targets else float("nan")
|
|
|
|
|
|
def objective_function(trial: tr.Trial):
    """Optuna objective for a single trial (incomplete).

    Builds a candidate model from ``trial`` via ``create_model`` but does not
    yet train or score it, so the function currently returns ``None``.

    NOTE(review): Optuna expects an objective to return a numeric score;
    until this function returns one, every trial will be recorded as failed.
    """
    candidate = create_model(trial)  # NOTE(review): unused until training is wired in
|
|
|
|
if __name__ == "__main__":
    # Use the GPU when PyTorch reports CUDA is available, otherwise the CPU.
    # NOTE(review): the entry point currently stops after choosing a device;
    # no model is created, trained, or tuned yet.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
|