feat: Add model choice
This commit is contained in:
parent
bb241154d9
commit
ef50d6321e
10 changed files with 102 additions and 54 deletions
|
|
@ -1,7 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
|
@ -15,8 +13,7 @@ class Trainer(ABC):
|
|||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
pass
|
||||
) -> nn.Module:
|
||||
pass
|
||||
|
|
|
|||
Reference in a new issue