"""Trainer that runs the shared training loop on an already-built model."""

from typing import Callable

import torch
from torch import nn
from torch.utils.data import DataLoader

from .trainer import Trainer
from .train import train
from ..utils import print_losses


class FullTrainer(Trainer):
    """Train a supplied model for a fixed number of epochs and print its losses."""

    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str,
    ) -> None:
        """Move ``model`` to ``device``, train it, and print the loss history.

        Args:
            model: The network to train. It must not be ``None``; the error
                message suggests it is normally produced by a prior Optuna
                search (presumably elsewhere in the pipeline).
            train_loader: Batches passed to ``train`` for fitting.
            validation_loader: Batches passed to ``train`` for validation.
            loss_fn: Callable mapping ``(prediction, target)`` tensors to a
                loss tensor; forwarded unchanged to ``train``.
            n_epochs: Number of epochs, forwarded to ``train``.
            device: Target passed to ``model.to``.

        Raises:
            ValueError: If ``model`` is ``None``.
        """
        if model is None:
            raise ValueError("Model must be provided: run optuna optimizations first")

        # nn.Module.to moves parameters in place, so the return value is not needed.
        # NOTE(review): ``device`` is not forwarded to ``train``; confirm that
        # ``train`` moves each batch onto the model's device.
        model.to(device)

        history = train(model, train_loader, validation_loader, loss_fn, n_epochs)
        train_losses, validation_losses = history
        print_losses(train_losses, validation_losses)