feat: uhm, i changed some things
This commit is contained in:
parent
b58682cb49
commit
6de4db24cc
27 changed files with 1302 additions and 137 deletions
22
CNN-model/trainers/trainer.py
Normal file
22
CNN-model/trainers/trainer.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class Trainer(ABC):
|
||||
"""Abstract class for trainers."""
|
||||
|
||||
@abstractmethod
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
pass
|
||||
Reference in a new issue