from abc import ABC, abstractmethod from typing import Callable import torch import torch.nn as nn from torch.utils.data import DataLoader class Trainer(ABC): """Abstract class for trainers.""" @abstractmethod def execute( self, model: nn.Module | None, train_loader: DataLoader, validation_loader: DataLoader, loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], n_epochs: int, device: str ) -> None: pass