feat: model --> ready to test train

This commit is contained in:
Robin Meersman 2025-11-08 20:55:05 +01:00
parent 63d1b6f5ae
commit b58682cb49
8 changed files with 382 additions and 17 deletions

View file

@ -4,9 +4,10 @@ import torch.nn.functional as F
import optuna.trial as tr
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
from optuna_trial import create_model
from data_utils import make_context_pairs
from utils import make_context_pairs, load_data
import optuna
# hyper parameters
@ -71,10 +72,22 @@ def objective_function(trial: tr.Trial, train_data: bytes, validation_data: byte
return result["best_validation_loss"]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train-data", type=str, required=True)
parser.add_argument("--validation-data", type=str, required=True)
parser.add_argument("--batch-size", type=int, default=128)
args = parser.parse_args()
print(args)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data = b""
validation_data = b""
batch_size = 0
train_data = load_data(args.train_data)
validation_data = load_data(args.validation_data)
batch_size = args.batch_size
print(f"training data length: {len(train_data)}")
print(f"validation data length: {len(validation_data)}")
print(f"batch size: {batch_size}")
study = optuna.create_study(study_name="CNN network",direction="minimize")
study.optimize(lambda trial: objective_function(trial, train_data, validation_data, batch_size), n_trials=10)