feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -1,93 +1,55 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import optuna.trial as tr
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
from argparse import ArgumentParser
from math import ceil
from optuna_trial import create_model
from utils import make_context_pairs, load_data
import optuna
import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets import EnWik9DataSet, LoremIpsumDataset
from trainers import OptunaTrainer, Trainer
BATCH_SIZE = 64
DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
# hyper parameters
context_length = 128
def train_and_eval(
model: nn.Module,
training_data: bytes,
validation_data: bytes,
batch_size: int,
epochs: int = 100,
learning_rate: float = 1e-3,
device: torch.device = torch.device("cpu")
) -> dict:
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
training_loader = DataLoader(make_context_pairs(training_data, context_length=context_length), batch_size=batch_size)
validation_loader= DataLoader(make_context_pairs(validation_data, context_length=context_length), batch_size=batch_size)
training_losses = []
validation_losses = []
best_val_loss = float("inf")
for epoch in range(epochs):
model.train()
train_loss = 0
for x, y in tqdm(training_loader, desc=f"Epoch {epoch}"):
x, y = x.to(device), y.to(device)
prediction = model(x)
loss = F.cross_entropy(prediction, y)
train_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
training_losses.append(train_loss / len(training_loader))
model.eval()
with torch.no_grad():
val_loss = 0
for x, y in validation_loader:
x, y = x.to(device), y.to(device)
prediction = model(x)
loss = F.cross_entropy(prediction, y)
val_loss += loss.item()
validation_losses.append(val_loss / len(validation_loader))
if validation_losses[-1] < best_val_loss:
best_val_loss = validation_losses[-1]
return {
"training_losses": training_losses,
"validation_losses": validation_losses,
"best_validation_loss": best_val_loss
}
def objective_function(trial: tr.Trial, train_data: bytes, validation_data: bytes, batch_size: int):
model = create_model(trial)
result = train_and_eval(model, train_data, validation_data, batch_size)
return result["best_validation_loss"]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train-data", type=str, required=True)
parser.add_argument("--validation-data", type=str, required=True)
parser.add_argument("--batch-size", type=int, default=128)
print(f"Running on device: {DEVICE}...")
parser = ArgumentParser()
parser.add_argument("--method", choices=["optuna", "train"], required=True)
parser.add_argument("--model-path", type=str, required=False)
args = parser.parse_args()
print(args)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data = load_data(args.train_data)
validation_data = load_data(args.validation_data)
batch_size = args.batch_size
if args.method == "train":
dataset = EnWik9DataSet()
elif args.method == "optuna":
dataset = LoremIpsumDataset()
else:
raise ValueError(f"Unknown method: {args.method}")
print(f"training data length: {len(train_data)}")
print(f"validation data length: {len(validation_data)}")
print(f"batch size: {batch_size}")
dataset_length = len(dataset)
training_size = ceil(0.8 * dataset_length)
study = optuna.create_study(study_name="CNN network",direction="minimize")
study.optimize(lambda trial: objective_function(trial, train_data, validation_data, batch_size), n_trials=10)
print(f"training set size = {training_size}, validation set size {dataset_length - training_size}")
data = dataset.data["text"]
train_set, validate_set = torch.utils.data.random_split(TensorDataset(data),
[training_size, dataset_length - training_size])
training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
loss_fn = torch.nn.CrossEntropyLoss()
model = None
if args.model_path is not None:
model = torch.load(args.model_path)
trainer: Trainer = OptunaTrainer() if args.method == "optuna" else None
trainer.execute(
model=model,
train_loader=training_loader,
validation_loader=validation_loader,
loss_fn=loss_fn,
n_epochs=200,
device=DEVICE
)