feat: Add model choice

commit ef50d6321e (parent bb241154d9)
Author: Tibo De Peuter
Date:   2025-12-06 21:52:31 +01:00
Signed by: tdpeuter (GPG key ID: 38297DE43F75FFE2)

10 changed files with 102 additions and 54 deletions


@@ -22,7 +22,9 @@ def main():
                 data_root = args.data_root,
                 n_trials = 3 if args.debug else None,
                 size = 2**10 if args.debug else None,
-                model_path = args.model_path
+                model_name = args.model,
+                model_path = args.model_load_path,
+                model_out = args.model_save_path
             )
         case 'compress':


@@ -15,8 +15,12 @@ def parse_arguments():
     dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)

     modelparser = ArgumentParser(add_help=False)
-    modelparser.add_argument("--model-path", type=str, required=False,
-                             help="Path to the model to load/save")
+    modelparser.add_argument("--model", "-m", type=str, required=False,
+                             help="Which model to use")
+    modelparser.add_argument("--model-load-path", type=str, required=False,
+                             help="Filepath to the model to load")
+    modelparser.add_argument("--model-save-path", type=str, required=True,
+                             help="Filepath to save the model to")

     fileparser = ArgumentParser(add_help=False)
     fileparser.add_argument("--input-file", "-i", required=False, type=str)
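The split of the old --model-path into separate load/save paths, plus the new --model selector, supports two workflows. A rough usage sketch (the entrypoint script and subcommand wiring are assumptions, not shown in this commit):

    # Train a fresh model picked by name
    python main.py train --dataset <dataset> --model cnn --model-save-path saved_models/cnn.pt

    # Continue from a saved checkpoint instead
    python main.py train --dataset <dataset> --model-load-path saved_models/cnn.pt --model-save-path saved_models/cnn-v2.pt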

src/models/Model.py (new file)

@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+
+from torch import nn
+
+
+class Model(nn.Module, ABC):
+    @abstractmethod
+    def __init__(self, loss_function = None):
+        super().__init__()
+        self._loss_function = loss_function
+
+    @property
+    def loss_function(self):
+        return self._loss_function
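A minimal sketch of how a subclass is expected to use this base; MyModel is hypothetical and assumes the src.models package is importable. The criterion is handed to super().__init__ once and read back through the read-only property:

    from torch import nn
    from src.models import Model

    class MyModel(Model):
        def __init__(self):
            super().__init__(nn.CrossEntropyLoss())  # stored as self._loss_function
            self.linear = nn.Linear(8, 8)

    criterion = MyModel().loss_function              # property returns the stored loss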


@@ -1,2 +1,9 @@
+from .Model import Model
 from .cnn import CNNPredictor
-from .transformer import Transformer
+from .transformer import Transformer
+
+
+model_called: dict[str, type[Model]] = {
+    'cnn': CNNPredictor,
+    'transformer': Transformer
+}
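model_called maps a CLI name to the class itself, not an instance; train() stores that class and the trainers decide when to instantiate it. Illustration (sketch):

    model_cls = model_called['cnn']   # CNNPredictor, a type[Model]
    model = model_cls()               # construct with default hyperparameters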


@@ -1,14 +1,16 @@
 import torch
 import torch.nn as nn

-class CNNPredictor(nn.Module):
+from src.models import Model
+
+class CNNPredictor(Model):
     def __init__(
         self,
         vocab_size=256,
         embed_dim=64,
         hidden_dim=128,
     ):
-        super().__init__()
+        super().__init__(nn.CrossEntropyLoss())

         # 1. Embedding: maps bytes (0–255) → vectors
         self.embed = nn.Embedding(vocab_size, embed_dim)


@@ -30,6 +30,7 @@ class Transformer(nn.Transformer):
             device=None,
             dtype=None
         )
+        self.loss_function = nn.CrossEntropyLoss()

     def forward(
         self,


@@ -1,7 +1,9 @@
 import torch
+from torch import nn
 from torch.utils.data import DataLoader

 from src.dataset_loaders import dataset_called
+from src.models import model_called
 from src.trainers import OptunaTrainer, Trainer, FullTrainer
@@ -13,10 +15,21 @@ def train(
     size: int | None = None,
     mode: str = "train",
     method: str = 'optuna',
+    model_name: str | None = None,
     model_path: str | None = None,
+    model_out: str | None = None
 ):
     batch_size = 2

+    assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
+    if model_name:
+        print("Creating model")
+        model = model_called[model_name]
+    else:
+        print("Loading model from disk")
+        model = torch.load(model_path)
+
     dataset_common_args = {
         'root': data_root,
         'transform': lambda x: x.to(device),
@@ -41,21 +54,16 @@ def train(
     training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
     validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)

-    loss_fn = torch.nn.CrossEntropyLoss()
-
-    model = None
-    if model_path is not None:
-        print("Loading the models...")
-        model = torch.load(model_path)
-
     trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()

     print("Training")
-    trainer.execute(
+    best_model = trainer.execute(
         model=model,
         train_loader=training_loader,
         validation_loader=validation_loader,
-        loss_fn=loss_fn,
         n_epochs=200,
         device=device
     )
+
+    print("Saving model...")
+    torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt")
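The reworked train() therefore has two entry modes, enforced by the assert. Roughly (parameter names from this commit, values are placeholders):

    # Fresh model: the class is looked up in model_called and instantiated by the trainer
    train(data_root="data", model_name="cnn", model_out="saved_models/cnn.pt")

    # Checkpoint: an existing model is torch.load()ed and tuned further
    train(data_root="data", model_path="saved_models/cnn.pt")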


@@ -1,26 +1,26 @@
 from typing import Callable

 import torch
-from torch import nn as nn
+from torch import nn
 from torch.utils.data import DataLoader

-from .trainer import Trainer
 from .train import train
+from .trainer import Trainer
+from ..models import Model
 from ..utils import print_losses


 class FullTrainer(Trainer):
     def execute(
         self,
-        model: nn.Module | None,
+        model: Model,
         train_loader: DataLoader,
         validation_loader: DataLoader,
-        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
         n_epochs: int,
         device: str
-    ) -> None:
+    ) -> nn.Module:
         if model is None:
             raise ValueError("Model must be provided: run optuna optimizations first")

         model.to(device)
-        train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
-        print_losses(train_loss, val_loss)
+        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs)
+        print_losses(train_loss, val_loss)
+        return model


@@ -3,60 +3,73 @@ from typing import Callable
 import optuna
 import optuna.trial as tr
 import torch
-from torch import nn as nn
+from torch import nn
 from torch.utils.data import DataLoader

-from .trainer import Trainer
-from ..models.cnn import CNNPredictor
 from .train import train
+from .trainer import Trainer
+from ..models import Model, CNNPredictor, Transformer


-def create_model(trial: tr.Trial, vocab_size: int = 256):
-    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
-    embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
-
-    return CNNPredictor(
-        vocab_size=vocab_size,
-        hidden_dim=hidden_dim,
-        embed_dim=embedding_dim,
-    )
+def create_model(trial: tr.Trial, model: nn.Module):
+    match model.__class__:
+        case CNNPredictor.__class__:
+            return model(
+                hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
+                embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
+                vocab_size=256,
+            )
+        case Transformer.__class__:
+            nhead = trial.suggest_int("nhead", 2, 8, log=True)
+            d_model = trial.suggest_int("d_model", 64, 512, step=nhead)
+            return model(
+                d_model=d_model,
+                nhead=nhead,
+                num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
+                num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
+                dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
+                dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
+                activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
+                layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
+            )
+    return None


 def objective_function(
     trial: tr.Trial,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    model: Model,
     device: str
 ):
-    model = create_model(trial).to(device)
-    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
+    model = create_model(trial, model).to(device)
+    _, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
     return min(validation_loss)


 class OptunaTrainer(Trainer):
     def __init__(self, n_trials: int | None = None):
         super().__init__()
-        self.n_trials = n_trials if n_trials is not None else 20
+        self.n_trials = n_trials if n_trials else 20
         print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

     def execute(
         self,
-        model: nn.Module | None,
+        model: Model,
         train_loader: DataLoader,
         validation_loader: DataLoader,
-        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
         n_epochs: int,
         device: str
-    ) -> None:
-        study = optuna.create_study(study_name="CNN network", direction="minimize")
+    ) -> nn.Module:
+        study = optuna.create_study(direction="minimize")
         study.optimize(
-            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
+            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
             n_trials=self.n_trials
         )

         best_params = study.best_trial.params
-        best_model = CNNPredictor(
+        best_model = model(
             **best_params
         )
-        torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")
+        return best_model
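Read end to end, the model that reaches OptunaTrainer is the class taken from model_called rather than an instance, so create_model's model(...) calls and the final model(**best_params) are both constructor invocations. Condensed (study plumbing elided):

    model = model_called['cnn']                     # type[Model], e.g. CNNPredictor
    # ... study.optimize(...) runs objective_function n_trials times ...
    best_model = model(**study.best_trial.params)   # rebuild with the winning params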


@@ -1,7 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable

-import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
@@ -15,8 +13,7 @@ class Trainer(ABC):
         model: nn.Module | None,
         train_loader: DataLoader,
         validation_loader: DataLoader,
-        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
         n_epochs: int,
         device: str
-    ) -> None:
-        pass
+    ) -> nn.Module:
+        pass
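With loss_fn gone from the contract (each Model now carries its own criterion) and execute returning the trained module, a custom trainer only needs this narrower interface. A stub sketch (hypothetical, for illustration):

    from torch import nn
    from torch.utils.data import DataLoader

    class EchoTrainer(Trainer):
        def execute(self, model: nn.Module, train_loader: DataLoader,
                    validation_loader: DataLoader, n_epochs: int, device: str) -> nn.Module:
            # No training: just satisfies the new contract by returning the module
            return model.to(device)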