fix: NoneType epochs
This commit is contained in:
parent
63119980c9
commit
6c5908e6ae
4 changed files with 6 additions and 4 deletions
|
|
@ -1,7 +1,6 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from src.dataset_loaders import dataset_called
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class FullTrainer(Trainer):
|
|||
model: Model,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
n_epochs: int,
|
||||
n_epochs: int | None,
|
||||
device: str
|
||||
) -> nn.Module:
|
||||
if model is None:
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ def train(
|
|||
training_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable,
|
||||
epochs: int = 100,
|
||||
epochs: int | None = None,
|
||||
learning_rate: float = 1e-3,
|
||||
weight_decay: float = 1e-8,
|
||||
device="cuda"
|
||||
|
|
@ -37,6 +37,9 @@ def train(
|
|||
avg_training_losses = []
|
||||
avg_validation_losses = []
|
||||
|
||||
if epochs is None:
|
||||
epochs = 100
|
||||
|
||||
for epoch in range(epochs):
|
||||
|
||||
model.train()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class Trainer(ABC):
|
|||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
n_epochs: int,
|
||||
n_epochs: int | None,
|
||||
device: str
|
||||
) -> nn.Module:
|
||||
pass
|
||||
|
|
|
|||
Reference in a new issue