feat: Add model choice
This commit is contained in:
parent
bb241154d9
commit
ef50d6321e
10 changed files with 102 additions and 54 deletions
|
|
@ -15,8 +15,12 @@ def parse_arguments():
|
|||
dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
|
||||
|
||||
modelparser = ArgumentParser(add_help=False)
|
||||
modelparser.add_argument("--model-path", type=str, required=False,
|
||||
help="Path to the model to load/save")
|
||||
modelparser.add_argument("--model", "-m", type=str, required=False,
|
||||
help="Which model to use")
|
||||
modelparser.add_argument("--model-load-path", type=str, required=False,
|
||||
help="Filepath to the model to load")
|
||||
modelparser.add_argument("--model-save-path", type=str, required=True,
|
||||
help="Filepath to the model to save")
|
||||
|
||||
fileparser = ArgumentParser(add_help=False)
|
||||
fileparser.add_argument("--input-file", "-i", required=False, type=str)
|
||||
|
|
|
|||
14
src/models/Model.py
Normal file
14
src/models/Model.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Model(nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, loss_function = None):
|
||||
super().__init__()
|
||||
self._loss_function = loss_function
|
||||
|
||||
@property
|
||||
def loss_function(self):
|
||||
return self._loss_function
|
||||
|
|
@ -1,2 +1,9 @@
|
|||
from .Model import Model
|
||||
from .cnn import CNNPredictor
|
||||
from .transformer import Transformer
|
||||
from .transformer import Transformer
|
||||
|
||||
|
||||
model_called: dict[str, type[Model]] = {
|
||||
'cnn': CNNPredictor,
|
||||
'transformer': Transformer
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class CNNPredictor(nn.Module):
|
||||
from src.models import Model
|
||||
|
||||
|
||||
class CNNPredictor(Model):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256,
|
||||
embed_dim=64,
|
||||
hidden_dim=128,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(nn.CrossEntropyLoss())
|
||||
|
||||
# 1. Embedding: maps bytes (0–255) → vectors
|
||||
self.embed = nn.Embedding(vocab_size, embed_dim)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ class Transformer(nn.Transformer):
|
|||
device=None,
|
||||
dtype=None
|
||||
)
|
||||
self.loss_function = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
|||
26
src/train.py
26
src/train.py
|
|
@ -1,7 +1,9 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from src.dataset_loaders import dataset_called
|
||||
from src.models import model_called
|
||||
from src.trainers import OptunaTrainer, Trainer, FullTrainer
|
||||
|
||||
|
||||
|
|
@ -13,10 +15,21 @@ def train(
|
|||
size: int | None = None,
|
||||
mode: str = "train",
|
||||
method: str = 'optuna',
|
||||
model_name: str | None = None,
|
||||
model_path: str | None = None,
|
||||
model_out: str | None = None
|
||||
):
|
||||
batch_size = 2
|
||||
|
||||
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
|
||||
|
||||
if model_name:
|
||||
print("Creating model")
|
||||
model = model_called[model_name]
|
||||
else:
|
||||
print("Loading model from disk")
|
||||
model = torch.load(model_path)
|
||||
|
||||
dataset_common_args = {
|
||||
'root': data_root,
|
||||
'transform': lambda x: x.to(device),
|
||||
|
|
@ -41,21 +54,16 @@ def train(
|
|||
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
|
||||
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
|
||||
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
|
||||
model = None
|
||||
if model_path is not None:
|
||||
print("Loading the models...")
|
||||
model = torch.load(model_path)
|
||||
|
||||
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
|
||||
|
||||
print("Training")
|
||||
trainer.execute(
|
||||
best_model = trainer.execute(
|
||||
model=model,
|
||||
train_loader=training_loader,
|
||||
validation_loader=validation_loader,
|
||||
loss_fn=loss_fn,
|
||||
n_epochs=200,
|
||||
device=device
|
||||
)
|
||||
|
||||
print("Saving model...")
|
||||
torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt")
|
||||
|
|
|
|||
|
|
@ -1,26 +1,26 @@
|
|||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .trainer import Trainer
|
||||
from .train import train
|
||||
from .trainer import Trainer
|
||||
from ..models import Model
|
||||
from ..utils import print_losses
|
||||
|
||||
|
||||
class FullTrainer(Trainer):
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
model: Model,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
) -> nn.Module:
|
||||
if model is None:
|
||||
raise ValueError("Model must be provided: run optuna optimizations first")
|
||||
|
||||
model.to(device)
|
||||
train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
|
||||
print_losses(train_loss, val_loss)
|
||||
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs)
|
||||
print_losses(train_loss, val_loss)
|
||||
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -3,60 +3,73 @@ from typing import Callable
|
|||
import optuna
|
||||
import optuna.trial as tr
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .trainer import Trainer
|
||||
from ..models.cnn import CNNPredictor
|
||||
from .train import train
|
||||
from .trainer import Trainer
|
||||
from ..models import Model, CNNPredictor, Transformer
|
||||
|
||||
|
||||
def create_model(trial: tr.Trial, vocab_size: int = 256):
|
||||
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
|
||||
embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
|
||||
|
||||
return CNNPredictor(
|
||||
vocab_size=vocab_size,
|
||||
hidden_dim=hidden_dim,
|
||||
embed_dim=embedding_dim,
|
||||
)
|
||||
def create_model(trial: tr.Trial, model: nn.Module):
|
||||
match model.__class__:
|
||||
case CNNPredictor.__class__:
|
||||
return model(
|
||||
hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
|
||||
embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
|
||||
vocab_size=256,
|
||||
)
|
||||
case Transformer.__class__:
|
||||
nhead = trial.suggest_int("nhead", 2, 8, log=True)
|
||||
d_model = trial.suggest_int("d_model", 64, 512, step=nhead)
|
||||
return model(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
|
||||
num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
|
||||
dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
|
||||
dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
|
||||
activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
|
||||
layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def objective_function(
|
||||
trial: tr.Trial,
|
||||
training_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
model: Model,
|
||||
device: str
|
||||
):
|
||||
model = create_model(trial).to(device)
|
||||
_, validation_loss = train(model, training_loader, validation_loader, loss_fn)
|
||||
model = create_model(trial, model).to(device)
|
||||
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
|
||||
return min(validation_loss)
|
||||
|
||||
|
||||
class OptunaTrainer(Trainer):
|
||||
def __init__(self, n_trials: int | None = None):
|
||||
super().__init__()
|
||||
self.n_trials = n_trials if n_trials is not None else 20
|
||||
self.n_trials = n_trials if n_trials else 20
|
||||
print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
|
||||
|
||||
def execute(
|
||||
self,
|
||||
model: nn.Module | None,
|
||||
model: Model,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
study = optuna.create_study(study_name="CNN network", direction="minimize")
|
||||
) -> nn.Module:
|
||||
study = optuna.create_study(direction="minimize")
|
||||
study.optimize(
|
||||
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
|
||||
lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
|
||||
n_trials=self.n_trials
|
||||
)
|
||||
|
||||
best_params = study.best_trial.params
|
||||
best_model = CNNPredictor(
|
||||
best_model = model(
|
||||
**best_params
|
||||
)
|
||||
torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")
|
||||
|
||||
return best_model
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
|
@ -15,8 +13,7 @@ class Trainer(ABC):
|
|||
model: nn.Module | None,
|
||||
train_loader: DataLoader,
|
||||
validation_loader: DataLoader,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
n_epochs: int,
|
||||
device: str
|
||||
) -> None:
|
||||
pass
|
||||
) -> nn.Module:
|
||||
pass
|
||||
|
|
|
|||
Reference in a new issue