feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Callable
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class Trainer(ABC):
"""Abstract class for trainers."""
@abstractmethod
def execute(
self,
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int,
device: str
) -> None:
pass