fix: NoneType epochs

This commit is contained in:
Tibo De Peuter 2025-12-10 15:37:54 +01:00
parent 63119980c9
commit 6c5908e6ae
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
4 changed files with 6 additions and 4 deletions

View file

@ -1,7 +1,6 @@
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import DataLoader
from src.dataset_loaders import dataset_called

View file

@ -13,7 +13,7 @@ class FullTrainer(Trainer):
model: Model,
train_loader: DataLoader,
validation_loader: DataLoader,
n_epochs: int,
n_epochs: int | None,
device: str
) -> nn.Module:
if model is None:

View file

@ -26,7 +26,7 @@ def train(
training_loader: DataLoader,
validation_loader: DataLoader,
loss_fn: Callable,
epochs: int = 100,
epochs: int | None = None,
learning_rate: float = 1e-3,
weight_decay: float = 1e-8,
device="cuda"
@ -37,6 +37,9 @@ def train(
avg_training_losses = []
avg_validation_losses = []
if epochs is None:
epochs = 100
for epoch in range(epochs):
model.train()

View file

@ -13,7 +13,7 @@ class Trainer(ABC):
model: nn.Module | None,
train_loader: DataLoader,
validation_loader: DataLoader,
n_epochs: int,
n_epochs: int | None,
device: str
) -> nn.Module:
pass