fix: NoneType epochs

This commit is contained in:
Tibo De Peuter 2025-12-10 15:37:54 +01:00
parent 63119980c9
commit 6c5908e6ae
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
4 changed files with 6 additions and 4 deletions

View file

@ -1,7 +1,6 @@
from pathlib import Path from pathlib import Path
import torch import torch
from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from src.dataset_loaders import dataset_called from src.dataset_loaders import dataset_called

View file

@ -13,7 +13,7 @@ class FullTrainer(Trainer):
model: Model, model: Model,
train_loader: DataLoader, train_loader: DataLoader,
validation_loader: DataLoader, validation_loader: DataLoader,
n_epochs: int, n_epochs: int | None,
device: str device: str
) -> nn.Module: ) -> nn.Module:
if model is None: if model is None:

View file

@ -26,7 +26,7 @@ def train(
training_loader: DataLoader, training_loader: DataLoader,
validation_loader: DataLoader, validation_loader: DataLoader,
loss_fn: Callable, loss_fn: Callable,
epochs: int = 100, epochs: int | None = None,
learning_rate: float = 1e-3, learning_rate: float = 1e-3,
weight_decay: float = 1e-8, weight_decay: float = 1e-8,
device="cuda" device="cuda"
@ -37,6 +37,9 @@ def train(
avg_training_losses = [] avg_training_losses = []
avg_validation_losses = [] avg_validation_losses = []
if epochs is None:
epochs = 100
for epoch in range(epochs): for epoch in range(epochs):
model.train() model.train()

View file

@ -13,7 +13,7 @@ class Trainer(ABC):
model: nn.Module | None, model: nn.Module | None,
train_loader: DataLoader, train_loader: DataLoader,
validation_loader: DataLoader, validation_loader: DataLoader,
n_epochs: int, n_epochs: int | None,
device: str device: str
) -> nn.Module: ) -> nn.Module:
pass pass