26 lines
No EOL
797 B
Python
26 lines
No EOL
797 B
Python
from typing import Callable
|
|
|
|
import torch
|
|
from torch import nn as nn
|
|
from torch.utils.data import DataLoader
|
|
|
|
from .trainer import Trainer
|
|
from .train import train
|
|
from ..utils import print_losses
|
|
|
|
class FullTrainer(Trainer):
    """Trainer whose ``execute`` trains a supplied model and reports its losses.

    The work itself is delegated to the project helpers ``train`` and
    ``print_losses``; this class only validates the model, places it on the
    requested device and wires the two helpers together.
    """

    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str,
    ) -> None:
        """Train ``model`` and print the resulting losses.

        Args:
            model: Network to train. ``None`` is rejected.
            train_loader: Loader forwarded to ``train`` as the training data.
            validation_loader: Loader forwarded to ``train`` as the
                validation data.
            loss_fn: Callable taking two tensors and returning a tensor;
                forwarded unchanged to ``train``.
            n_epochs: Epoch count forwarded to ``train``.
            device: Device identifier handed to ``model.to``.

        Raises:
            ValueError: If ``model`` is ``None``.
        """
        # Guard first: nothing below makes sense without a model.
        if model is None:
            raise ValueError("Model must be provided: run optuna optimizations first")

        model.to(device)

        # NOTE(review): ``device`` is not forwarded to ``train``; presumably the
        # helper resolves batch placement on its own -- confirm against
        # ``train``.
        training_loss, validation_loss = train(
            model,
            train_loader,
            validation_loader,
            loss_fn,
            n_epochs,
        )
        print_losses(training_loss, validation_loss)