chore: Restructure

This commit is contained in:
Tibo De Peuter 2025-12-05 12:37:48 +01:00
parent 8b6c4e17ab
commit f32f4678e1
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
62 changed files with 0 additions and 10547 deletions

View file

@ -0,0 +1,26 @@
from typing import Callable
import torch
from torch import nn as nn
from torch.utils.data import DataLoader
from .trainer import Trainer
from .train import train
from ..utils import print_losses
class FullTrainer(Trainer):
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
if model is None:
raise ValueError("Model must be provided: run optuna optimizations first")
model.to(device)
train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
print_losses(train_loss, val_loss)