feat: restructure training around Trainer classes, add dataset wrappers, fix CausalConv1d output trimming

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -0,0 +1,5 @@
import lorem
class LoremIpsumDataset:
    def __init__(self):
        self.data = lorem.text(paragraphs=100)

View file

@ -0,0 +1,2 @@
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset

View file

@ -1,93 +1,55 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import optuna.trial as tr
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
from argparse import ArgumentParser
from math import ceil
from optuna_trial import create_model
from utils import make_context_pairs, load_data
import optuna
import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets import EnWik9DataSet, LoremIpsumDataset
from trainers import FullTrainer, OptunaTrainer, Trainer
BATCH_SIZE = 64
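# prefer the detected accelerator (e.g. cuda or mps) and fall back to cpu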
DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
# hyperparameters
context_length = 128
def train_and_eval(
    model: nn.Module,
    training_data: bytes,
    validation_data: bytes,
    batch_size: int,
    epochs: int = 100,
    learning_rate: float = 1e-3,
    device: torch.device = torch.device("cpu")
) -> dict:
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    training_loader = DataLoader(make_context_pairs(training_data, context_length=context_length), batch_size=batch_size)
    validation_loader = DataLoader(make_context_pairs(validation_data, context_length=context_length), batch_size=batch_size)
    training_losses = []
    validation_losses = []
    best_val_loss = float("inf")
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for x, y in tqdm(training_loader, desc=f"Epoch {epoch}"):
            x, y = x.to(device), y.to(device)
            prediction = model(x)
            loss = F.cross_entropy(prediction, y)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        training_losses.append(train_loss / len(training_loader))
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for x, y in validation_loader:
                x, y = x.to(device), y.to(device)
                prediction = model(x)
                loss = F.cross_entropy(prediction, y)
                val_loss += loss.item()
        validation_losses.append(val_loss / len(validation_loader))
        if validation_losses[-1] < best_val_loss:
            best_val_loss = validation_losses[-1]
    return {
        "training_losses": training_losses,
        "validation_losses": validation_losses,
        "best_validation_loss": best_val_loss
    }
def objective_function(trial: tr.Trial, train_data: bytes, validation_data: bytes, batch_size: int):
    model = create_model(trial)
    result = train_and_eval(model, train_data, validation_data, batch_size)
    return result["best_validation_loss"]
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-data", type=str, required=True)
    parser.add_argument("--validation-data", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=128)
    print(f"Running on device: {DEVICE}...")
    parser = ArgumentParser()
    parser.add_argument("--method", choices=["optuna", "train"], required=True)
    parser.add_argument("--model-path", type=str, required=False)
    args = parser.parse_args()
    print(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_data = load_data(args.train_data)
    validation_data = load_data(args.validation_data)
    batch_size = args.batch_size
    if args.method == "train":
        dataset = EnWik9DataSet()
    elif args.method == "optuna":
        dataset = LoremIpsumDataset()
    else:
        raise ValueError(f"Unknown method: {args.method}")
    print(f"training data length: {len(train_data)}")
    print(f"validation data length: {len(validation_data)}")
    print(f"batch size: {batch_size}")
    dataset_length = len(dataset)
    training_size = ceil(0.8 * dataset_length)
    study = optuna.create_study(study_name="CNN network", direction="minimize")
    study.optimize(lambda trial: objective_function(trial, train_data, validation_data, batch_size), n_trials=10)
    print(f"training set size = {training_size}, validation set size {dataset_length - training_size}")
    data = dataset.data["text"]
    train_set, validate_set = torch.utils.data.random_split(TensorDataset(data),
                                                            [training_size, dataset_length - training_size])
    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
    loss_fn = torch.nn.CrossEntropyLoss()
    model = None
    if args.model_path is not None:
        model = torch.load(args.model_path)
    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
    trainer.execute(
        model=model,
        train_loader=training_loader,
        validation_loader=validation_loader,
        loss_fn=loss_fn,
        n_epochs=200,
        device=DEVICE
    )
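For reference, the reworked entry point would be invoked roughly like this (a sketch; the script name is assumed, since this view does not show file paths):

python main.py --method optuna
python main.py --method train --model-path models/final_model.pt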

View file

@ -0,0 +1 @@
from .cnn import CNNPredictor

View file

@ -8,7 +8,7 @@ class CausalConv1d(nn.Conv1d):
    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
        super().__init__(input_channels, output_channels, kernel_size, padding=kernel_size-1, **kwargs)
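        # padding=kernel_size-1 pads both ends of the sequence; forward() trims the tail so position t never sees future inputs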
    def forward(self, input: Tensor) -> Tensor:
        return super().forward(input)
        return super().forward(input)[:, :, :input.size(-1)]
class CNNPredictor(nn.Module):
    def __init__(
@ -41,5 +41,5 @@ class CNNPredictor(nn.Module):
        emdedding = emdedding.transpose(1, 2)  # B, H, L
        prediction = self.network(emdedding)
        last_prediction = prediction[:, :, -1]
        return softmax(self.output_layer(last_prediction), dim=-1)  # convert output of linear layer to prob. distr.
        return self.output_layer(last_prediction)
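Returning raw logits here is consistent with torch.nn.CrossEntropyLoss, which applies log-softmax internally; applying softmax first would distort the loss. The causal trim in CausalConv1d can be sanity-checked with a sketch like this (the import path is illustrative, since file names are not shown in this view):

import torch
from model.cnn import CausalConv1d  # assumed path

conv = CausalConv1d(input_channels=1, output_channels=1, kernel_size=3)
x = torch.randn(1, 1, 10)
y = conv(x)
assert y.shape == x.shape  # the trim keeps output length equal to input length
x2 = x.clone()
x2[..., -1] += 1.0  # perturb only the final timestep
assert torch.allclose(conv(x2)[..., :-1], y[..., :-1])  # earlier outputs are unaffected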

View file

@ -1,18 +0,0 @@
import optuna.trial as tr
from cnn import CNNPredictor
def create_model(trial: tr.Trial, vocab_size: int = 256, context_length: int = 128):
    num_layers = trial.suggest_int("num_layers", 1, 6)
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
    kernel_size = trial.suggest_int("kernel_size", 2, 7)
    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
    return CNNPredictor(
        vocab_size=vocab_size,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        kernel_size=kernel_size,
        dropout_prob=dropout_prob,
        use_batchnorm=use_batchnorm
    )

View file

@ -0,0 +1,26 @@
from typing import Callable
import torch
from torch import nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from .train import train
from ..utils import print_losses
class FullTrainer(Trainer):
    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str
    ) -> None:
        if model is None:
            raise ValueError("Model must be provided: run optuna optimizations first")
        model.to(device)
        train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
        print_losses(train_loss, val_loss)

View file

@ -0,0 +1,63 @@
from typing import Callable
import optuna
import optuna.trial as tr
import torch
from torch import nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from ..model.cnn import CNNPredictor
from .train import train
def create_model(trial: tr.Trial, vocab_size: int = 256):
    num_layers = trial.suggest_int("num_layers", 1, 6)
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
    kernel_size = trial.suggest_int("kernel_size", 2, 7)
    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
    return CNNPredictor(
        vocab_size=vocab_size,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        kernel_size=kernel_size,
        dropout_prob=dropout_prob,
        use_batchnorm=use_batchnorm
    )
def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: str
):
    model = create_model(trial).to(device)
    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
    return min(validation_loss)
class OptunaTrainer(Trainer):
    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str
    ) -> None:
        study = optuna.create_study(study_name="CNN network", direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
            n_trials=20
        )
        best_params = study.best_trial.params
        best_model = CNNPredictor(**best_params)
        torch.save(best_model, "models/final_model.pt")
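To reuse the persisted model, it can be reloaded as a whole module (a sketch; on PyTorch 2.6+ the weights_only=False flag is required because torch.save above pickles the full nn.Module rather than a state dict):

import torch

model = torch.load("models/final_model.pt", weights_only=False)
model.eval()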

View file

@ -0,0 +1,3 @@
from .FullTrainer import FullTrainer  # module name assumed to match the class, as with OptunaTrainer
from .OptunaTrainer import OptunaTrainer
from .trainer import Trainer

View file

@ -0,0 +1,50 @@
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from typing import Callable
def train(
    model: nn.Module,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epochs: int = 100,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-8
) -> tuple[list[float], list[float]]:
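    # AdamW decouples weight decay from the gradient update and applies it directly to the parameters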
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    device = next(model.parameters()).device  # run batches wherever the model lives
    avg_training_losses = []
    avg_validation_losses = []
    for epoch in range(epochs):
        model.train()
        total_loss = []
        for data in tqdm(training_loader):
            data = data.to(device)
            optimizer.zero_grad()
            x_hat = model(data)
            loss = loss_fn(x_hat, data)
            loss.backward()
            optimizer.step()
            total_loss.append(loss.item())
        avg_training_losses.append(sum(total_loss) / len(total_loss))
        model.eval()  # disable dropout and batchnorm updates during validation
        with torch.no_grad():
            losses = []
            for data in validation_loader:
                data = data.to(device)
                x_hat = model(data)
                loss = loss_fn(x_hat, data)
                losses.append(loss.item())
            avg_loss = sum(losses) / len(losses)
            avg_validation_losses.append(avg_loss)
            tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")
    return avg_training_losses, avg_validation_losses
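A minimal call site for train(), as a sketch (the names here are illustrative; note the loop above treats each batch as a single tensor, so the loaders must yield plain tensors):

inputs = torch.randn(256, 16)
model = nn.Linear(16, 16)
loader = DataLoader(inputs, batch_size=32)
train_losses, val_losses = train(model, loader, loader, nn.MSELoss(), epochs=2)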

View file

@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Callable
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class Trainer(ABC):
"""Abstract class for trainers."""
@abstractmethod
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
pass

View file

@ -0,0 +1 @@
from .utils import *

View file

@ -14,6 +14,17 @@ def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
    plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
    plt.show()
def print_losses(train_losses: list[float], validation_losses: list[float], show=False):
    plt.plot(train_losses, label="Training loss")
    plt.plot(validation_losses, label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss (cross entropy)")
    plt.legend()
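    # save before any show: plt.show() releases the figure, which would leave savefig writing a blank image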
    plt.savefig("losses.png")
    if show:
        plt.show()
def load_data(path: str) -> bytes:
    with open(path, "rb") as f: