fix: NoneType epochs
This commit is contained in:
parent
63119980c9
commit
6c5908e6ae
4 changed files with 6 additions and 4 deletions
|
|
@ -1,7 +1,6 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from src.dataset_loaders import dataset_called
|
from src.dataset_loaders import dataset_called
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ class FullTrainer(Trainer):
|
||||||
model: Model,
|
model: Model,
|
||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
n_epochs: int,
|
n_epochs: int | None,
|
||||||
device: str
|
device: str
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
if model is None:
|
if model is None:
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ def train(
|
||||||
training_loader: DataLoader,
|
training_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
loss_fn: Callable,
|
loss_fn: Callable,
|
||||||
epochs: int = 100,
|
epochs: int | None = None,
|
||||||
learning_rate: float = 1e-3,
|
learning_rate: float = 1e-3,
|
||||||
weight_decay: float = 1e-8,
|
weight_decay: float = 1e-8,
|
||||||
device="cuda"
|
device="cuda"
|
||||||
|
|
@ -37,6 +37,9 @@ def train(
|
||||||
avg_training_losses = []
|
avg_training_losses = []
|
||||||
avg_validation_losses = []
|
avg_validation_losses = []
|
||||||
|
|
||||||
|
if epochs is None:
|
||||||
|
epochs = 100
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ class Trainer(ABC):
|
||||||
model: nn.Module | None,
|
model: nn.Module | None,
|
||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
n_epochs: int,
|
n_epochs: int | None,
|
||||||
device: str
|
device: str
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
Reference in a new issue