code cleanup
This commit is contained in:
parent
ea9cf12db0
commit
73d1742cbd
44 changed files with 6 additions and 2835 deletions
22
trainers/trainer.py
Normal file
22
trainers/trainer.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class Trainer(ABC):
|
||||
"""Abstract class for trainers."""
|
||||
|
||||
@abstractmethod
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
pass
|
||||
Reference in a new issue