feat: uhm, i changed some things
This commit is contained in:
parent
b58682cb49
commit
6de4db24cc
27 changed files with 1302 additions and 137 deletions
26
CNN-model/trainers/FullTrainer.py
Normal file
26
CNN-model/trainers/FullTrainer.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from trainer import Trainer
|
||||
from train import train
|
||||
from ..utils import print_losses
|
||||
|
||||
class FullTrainer(Trainer):
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
if model is None:
|
||||
raise ValueError("Model must be provided: run optuna optimizations first")
|
||||
|
||||
model.to(device)
|
||||
train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
|
||||
print_losses(train_loss, val_loss)
|
||||
63
CNN-model/trainers/OptunaTrainer.py
Normal file
63
CNN-model/trainers/OptunaTrainer.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from typing import Callable
|
||||
|
||||
import optuna
|
||||
import optuna.trial as tr
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from trainer import Trainer
|
||||
from ..model.cnn import CNNPredictor
|
||||
from train import train
|
||||
|
||||
|
||||
def create_model(trial: tr.Trial, vocab_size: int = 256):
|
||||
num_layers = trial.suggest_int("num_layers", 1, 6)
|
||||
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
|
||||
kernel_size = trial.suggest_int("kernel_size", 2, 7)
|
||||
dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
|
||||
use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
|
||||
|
||||
return CNNPredictor(
|
||||
vocab_size=vocab_size,
|
||||
num_layers=num_layers,
|
||||
hidden_dim=hidden_dim,
|
||||
kernel_size=kernel_size,
|
||||
dropout_prob=dropout_prob,
|
||||
use_batchnorm=use_batchnorm
|
||||
)
|
||||
|
||||
|
||||
def objective_function(
|
||||
trial: tr.Trial,
|
||||
training_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
device: str
|
||||
):
|
||||
model = create_model(trial).to(device)
|
||||
_, validation_loss = train(model, training_loader, validation_loader, loss_fn)
|
||||
return min(validation_loss)
|
||||
|
||||
|
||||
class OptunaTrainer(Trainer):
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
study = optuna.create_study(study_name="CNN network", direction="minimize")
|
||||
study.optimize(
|
||||
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
|
||||
n_trials=20
|
||||
)
|
||||
|
||||
best_params = study.best_trial.params
|
||||
best_model = CNNPredictor(
|
||||
**best_params
|
||||
)
|
||||
torch.save(best_model, "models/final_model.pt")
|
||||
2
CNN-model/trainers/__init__.py
Normal file
2
CNN-model/trainers/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from OptunaTrainer import OptunaTrainer
|
||||
from trainer import Trainer
|
||||
50
CNN-model/trainers/train.py
Normal file
50
CNN-model/trainers/train.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def train(
|
||||
model: nn.Module,
|
||||
training_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
epochs: int = 100,
|
||||
learning_rate: float = 1e-3,
|
||||
weight_decay: float = 1e-8
|
||||
) -> tuple[list[float], list[float]]:
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||
avg_training_losses = []
|
||||
avg_validation_losses = []
|
||||
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
total_loss = []
|
||||
|
||||
for data in tqdm(training_loader):
|
||||
optimizer.zero_grad()
|
||||
|
||||
x_hat = model(data)
|
||||
|
||||
loss = loss_fn(x_hat, data)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss.append(loss.item())
|
||||
|
||||
avg_training_losses.append(sum(total_loss) / len(total_loss))
|
||||
|
||||
with torch.no_grad():
|
||||
losses = []
|
||||
for data in validation_loader:
|
||||
x_hat = model(data)
|
||||
loss = loss_fn(x_hat, data)
|
||||
losses.append(loss.item())
|
||||
avg_loss = sum(losses) / len(losses)
|
||||
avg_validation_losses.append(avg_loss)
|
||||
tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")
|
||||
|
||||
return avg_training_losses, avg_validation_losses
|
||||
|
||||
22
CNN-model/trainers/trainer.py
Normal file
22
CNN-model/trainers/trainer.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class Trainer(ABC):
|
||||
"""Abstract class for trainers."""
|
||||
|
||||
@abstractmethod
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
pass
|
||||
Reference in a new issue