diff --git a/src/train.py b/src/train.py index e2bcabb..ee4a99a 100644 --- a/src/train.py +++ b/src/train.py @@ -1,7 +1,6 @@ from pathlib import Path import torch -from torch import nn from torch.utils.data import DataLoader from src.dataset_loaders import dataset_called diff --git a/src/trainers/FullTrainer.py b/src/trainers/FullTrainer.py index d3e8bde..cfe9b08 100644 --- a/src/trainers/FullTrainer.py +++ b/src/trainers/FullTrainer.py @@ -13,7 +13,7 @@ class FullTrainer(Trainer): model: Model, train_loader: DataLoader, validation_loader: DataLoader, - n_epochs: int, + n_epochs: int | None, device: str ) -> nn.Module: if model is None: diff --git a/src/trainers/train.py b/src/trainers/train.py index d26e7de..61a6d09 100644 --- a/src/trainers/train.py +++ b/src/trainers/train.py @@ -26,7 +26,7 @@ def train( training_loader: DataLoader, validation_loader: DataLoader, loss_fn: Callable, - epochs: int = 100, + epochs: int | None = None, learning_rate: float = 1e-3, weight_decay: float = 1e-8, device="cuda" @@ -37,6 +37,9 @@ def train( avg_training_losses = [] avg_validation_losses = [] + if epochs is None: + epochs = 100 + for epoch in range(epochs): model.train() diff --git a/src/trainers/trainer.py b/src/trainers/trainer.py index 228f924..19e6480 100644 --- a/src/trainers/trainer.py +++ b/src/trainers/trainer.py @@ -13,7 +13,7 @@ class Trainer(ABC): model: nn.Module | None, train_loader: DataLoader, validation_loader: DataLoader, - n_epochs: int, + n_epochs: int | None, device: str ) -> nn.Module: pass