This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../CNN-model/main_cnn.py
2025-11-07 23:17:29 +01:00

35 lines
1 KiB
Python

import torch
import torch.nn as nn
import torch.functional as F
import optuna.trial as tr
from torch.utils.data import DataLoader
from optuna_trial import create_model
from data_utils import make_context_pairs
# Hyperparameters
# Context length forwarded to make_context_pairs when the training and
# validation datasets are built.
context_length = 128
def train_and_eval(
    model: nn.Module,
    training_data: bytes,
    validation_data: bytes,
    batch_size: int,
    epochs: int = 100,
    learning_rate: float = 1e-3,
    device: torch.device = torch.device("cpu"),
):
    """Set up training for ``model`` and run the epoch loop.

    The model is moved to ``device``, an AdamW optimizer is created over its
    parameters, and batched loaders are built from the context pairs of the
    training and validation data.

    NOTE(review): the epoch loop currently only switches the model to
    training mode; the forward/backward pass, the optimizer step and the
    validation pass are not implemented yet, so the optimizer and both
    loaders are created but not consumed.

    Args:
        model: Network to train; moved to ``device``.
        training_data: Raw bytes handed to ``make_context_pairs`` to build
            the training dataset.
        validation_data: Raw bytes handed to ``make_context_pairs`` to build
            the validation dataset.
        batch_size: Number of samples per batch for both loaders.
        epochs: Number of iterations of the epoch loop.
        learning_rate: Learning rate given to AdamW.
        device: Device the model is moved to.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Forward batch_size explicitly: without it DataLoader falls back to its
    # default of one sample per batch and the parameter is silently ignored.
    training_loader = DataLoader(
        make_context_pairs(training_data, context_length=context_length),
        batch_size=batch_size,
    )
    validation_loader = DataLoader(
        make_context_pairs(validation_data, context_length=context_length),
        batch_size=batch_size,
    )

    for epoch in range(epochs):
        model.train()
def objective_function(trial: tr.Trial):
    """Build the candidate model for one Optuna trial.

    NOTE(review): only the model construction via ``create_model(trial)`` is
    implemented; nothing is trained or evaluated and no objective value is
    returned yet.
    """
    candidate_model = create_model(trial)
if __name__ == "__main__":
    # Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )