22 lines
No EOL
528 B
Python
22 lines
No EOL
528 B
Python
from abc import ABC, abstractmethod
|
|
from typing import Callable
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
class Trainer(ABC):
|
|
"""Abstract class for trainers."""
|
|
|
|
@abstractmethod
|
|
def execute(
|
|
self,
|
|
model: nn.Module | None,
|
|
train_loader: DataLoader,
|
|
validation_loader: DataLoader,
|
|
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
|
n_epochs: int,
|
|
device: str
|
|
) -> None:
|
|
pass |