chore: Restructure
This commit is contained in:
parent
8b6c4e17ab
commit
f32f4678e1
62 changed files with 0 additions and 10547 deletions
26
src/trainers/FullTrainer.py
Normal file
26
src/trainers/FullTrainer.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .trainer import Trainer
|
||||
from .train import train
|
||||
from ..utils import print_losses
|
||||
|
||||
class FullTrainer(Trainer):
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
if model is None:
|
||||
raise ValueError("Model must be provided: run optuna optimizations first")
|
||||
|
||||
model.to(device)
|
||||
train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
|
||||
print_losses(train_loss, val_loss)
|
||||
Reference in a new issue