From 6c5908e6aed0579b866c1130d09d8afa6a30b9b0 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Wed, 10 Dec 2025 15:37:54 +0100 Subject: [PATCH] fix: NoneType epochs --- src/train.py | 1 - src/trainers/FullTrainer.py | 2 +- src/trainers/train.py | 5 ++++- src/trainers/trainer.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/train.py b/src/train.py index e2bcabb..ee4a99a 100644 --- a/src/train.py +++ b/src/train.py @@ -1,7 +1,6 @@ from pathlib import Path import torch -from torch import nn from torch.utils.data import DataLoader from src.dataset_loaders import dataset_called diff --git a/src/trainers/FullTrainer.py b/src/trainers/FullTrainer.py index d3e8bde..cfe9b08 100644 --- a/src/trainers/FullTrainer.py +++ b/src/trainers/FullTrainer.py @@ -13,7 +13,7 @@ class FullTrainer(Trainer): model: Model, train_loader: DataLoader, validation_loader: DataLoader, - n_epochs: int, + n_epochs: int | None, device: str ) -> nn.Module: if model is None: diff --git a/src/trainers/train.py b/src/trainers/train.py index d26e7de..61a6d09 100644 --- a/src/trainers/train.py +++ b/src/trainers/train.py @@ -26,7 +26,7 @@ def train( training_loader: DataLoader, validation_loader: DataLoader, loss_fn: Callable, - epochs: int = 100, + epochs: int | None = None, learning_rate: float = 1e-3, weight_decay: float = 1e-8, device="cuda" @@ -37,6 +37,9 @@ def train( avg_training_losses = [] avg_validation_losses = [] + if epochs is None: + epochs = 100 + for epoch in range(epochs): model.train() diff --git a/src/trainers/trainer.py b/src/trainers/trainer.py index 228f924..19e6480 100644 --- a/src/trainers/trainer.py +++ b/src/trainers/trainer.py @@ -13,7 +13,7 @@ class Trainer(ABC): model: nn.Module | None, train_loader: DataLoader, validation_loader: DataLoader, - n_epochs: int, + n_epochs: int | None, device: str ) -> nn.Module: pass