feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -0,0 +1,26 @@
from typing import Callable
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from trainer import Trainer
from train import train
from ..utils import print_losses
class FullTrainer(Trainer):
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
if model is None:
raise ValueError("Model must be provided: run optuna optimizations first")
model.to(device)
train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
print_losses(train_loss, val_loss)

View file

@ -0,0 +1,63 @@
from typing import Callable
import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from trainer import Trainer
from ..model.cnn import CNNPredictor
from train import train
def create_model(trial: tr.Trial, vocab_size: int = 256):
num_layers = trial.suggest_int("num_layers", 1, 6)
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
kernel_size = trial.suggest_int("kernel_size", 2, 7)
dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
return CNNPredictor(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_dim=hidden_dim,
kernel_size=kernel_size,
dropout_prob=dropout_prob,
use_batchnorm=use_batchnorm
)
def objective_function(
trial: tr.Trial,
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
device: str
):
model = create_model(trial).to(device)
_, validation_loss = train(model, training_loader, validation_loader, loss_fn)
return min(validation_loss)
class OptunaTrainer(Trainer):
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
study = optuna.create_study(study_name="CNN network", direction="minimize")
study.optimize(
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
n_trials=20
)
best_params = study.best_trial.params
best_model = CNNPredictor(
**best_params
)
torch.save(best_model, "models/final_model.pt")

View file

@ -0,0 +1,2 @@
from OptunaTrainer import OptunaTrainer
from trainer import Trainer

View file

@ -0,0 +1,50 @@
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from typing import Callable
def train(
model: nn.Module,
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
epochs: int = 100,
learning_rate: float = 1e-3,
weight_decay: float = 1e-8
) -> tuple[list[float], list[float]]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
avg_training_losses = []
avg_validation_losses = []
for epoch in range(epochs):
model.train()
total_loss = []
for data in tqdm(training_loader):
optimizer.zero_grad()
x_hat = model(data)
loss = loss_fn(x_hat, data)
loss.backward()
optimizer.step()
total_loss.append(loss.item())
avg_training_losses.append(sum(total_loss) / len(total_loss))
with torch.no_grad():
losses = []
for data in validation_loader:
x_hat = model(data)
loss = loss_fn(x_hat, data)
losses.append(loss.item())
avg_loss = sum(losses) / len(losses)
avg_validation_losses.append(avg_loss)
tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")
return avg_training_losses, avg_validation_losses

View file

@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Callable
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class Trainer(ABC):
"""Abstract class for trainers."""
@abstractmethod
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
pass