feat: Add model choice

This commit is contained in:
Tibo De Peuter 2025-12-06 21:52:31 +01:00
parent bb241154d9
commit ef50d6321e
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
10 changed files with 102 additions and 54 deletions

View file

@ -22,7 +22,9 @@ def main():
data_root = args.data_root, data_root = args.data_root,
n_trials = 3 if args.debug else None, n_trials = 3 if args.debug else None,
size = 2**10 if args.debug else None, size = 2**10 if args.debug else None,
model_path = args.model_path model_name=args.model,
model_path = args.model_load_path,
model_out = args.model_save_path
) )
case 'compress': case 'compress':

View file

@ -15,8 +15,12 @@ def parse_arguments():
dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True) dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
modelparser = ArgumentParser(add_help=False) modelparser = ArgumentParser(add_help=False)
modelparser.add_argument("--model-path", type=str, required=False, modelparser.add_argument("--model", "-m", type=str, required=False,
help="Path to the model to load/save") help="Which model to use")
modelparser.add_argument("--model-load-path", type=str, required=False,
help="Filepath to the model to load")
modelparser.add_argument("--model-save-path", type=str, required=True,
help="Filepath to the model to save")
fileparser = ArgumentParser(add_help=False) fileparser = ArgumentParser(add_help=False)
fileparser.add_argument("--input-file", "-i", required=False, type=str) fileparser.add_argument("--input-file", "-i", required=False, type=str)

14
src/models/Model.py Normal file
View file

@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from torch import nn
class Model(nn.Module, ABC):
    """Abstract base class for models that carry their own loss function.

    ``__init__`` is declared abstract, so ``Model`` itself cannot be
    instantiated; a concrete subclass overrides ``__init__`` and forwards
    its loss to ``super().__init__(...)``.
    """

    @abstractmethod
    def __init__(self, loss_function=None):
        """Initialise the underlying ``nn.Module`` and store ``loss_function``.

        Args:
            loss_function: Object exposed through the ``loss_function``
                property; defaults to ``None``.
        """
        super().__init__()
        # If the loss is itself an nn.Module, nn.Module's attribute handling
        # registers it as a submodule under the name ``_loss_function``.
        self._loss_function = loss_function

    @property
    def loss_function(self):
        """The loss function given at construction time (read-only)."""
        return self._loss_function

View file

@ -1,2 +1,9 @@
from .Model import Model
from .cnn import CNNPredictor from .cnn import CNNPredictor
from .transformer import Transformer from .transformer import Transformer
# Lookup table from a model's short name to the class that implements it.
# NOTE(review): the annotation promises Model subclasses — confirm that
# Transformer actually derives from Model.
model_called: dict[str, type[Model]] = {
    'cnn': CNNPredictor,
    'transformer': Transformer
}

View file

@ -1,14 +1,16 @@
import torch
import torch.nn as nn import torch.nn as nn
class CNNPredictor(nn.Module): from src.models import Model
class CNNPredictor(Model):
def __init__( def __init__(
self, self,
vocab_size=256, vocab_size=256,
embed_dim=64, embed_dim=64,
hidden_dim=128, hidden_dim=128,
): ):
super().__init__() super().__init__(nn.CrossEntropyLoss())
# 1. Embedding: maps bytes (0255) → vectors # 1. Embedding: maps bytes (0255) → vectors
self.embed = nn.Embedding(vocab_size, embed_dim) self.embed = nn.Embedding(vocab_size, embed_dim)

View file

@ -30,6 +30,7 @@ class Transformer(nn.Transformer):
device=None, device=None,
dtype=None dtype=None
) )
self.loss_function = nn.CrossEntropyLoss()
def forward( def forward(
self, self,

View file

@ -1,7 +1,9 @@
import torch import torch
from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from src.dataset_loaders import dataset_called from src.dataset_loaders import dataset_called
from src.models import model_called
from src.trainers import OptunaTrainer, Trainer, FullTrainer from src.trainers import OptunaTrainer, Trainer, FullTrainer
@ -13,10 +15,21 @@ def train(
size: int | None = None, size: int | None = None,
mode: str = "train", mode: str = "train",
method: str = 'optuna', method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None, model_path: str | None = None,
model_out: str | None = None
): ):
batch_size = 2 batch_size = 2
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
if model_name:
print("Creating model")
model = model_called[model_name]
else:
print("Loading model from disk")
model = torch.load(model_path)
dataset_common_args = { dataset_common_args = {
'root': data_root, 'root': data_root,
'transform': lambda x: x.to(device), 'transform': lambda x: x.to(device),
@ -41,21 +54,16 @@ def train(
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True) training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False) validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
loss_fn = torch.nn.CrossEntropyLoss()
model = None
if model_path is not None:
print("Loading the models...")
model = torch.load(model_path)
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer() trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
print("Training") print("Training")
trainer.execute( best_model = trainer.execute(
model=model, model=model,
train_loader=training_loader, train_loader=training_loader,
validation_loader=validation_loader, validation_loader=validation_loader,
loss_fn=loss_fn,
n_epochs=200, n_epochs=200,
device=device device=device
) )
print("Saving model...")
torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt")

View file

@ -1,26 +1,26 @@
from typing import Callable from torch import nn
import torch
from torch import nn as nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from .trainer import Trainer
from .train import train from .train import train
from .trainer import Trainer
from ..models import Model
from ..utils import print_losses from ..utils import print_losses
class FullTrainer(Trainer):
    """Trainer that runs a single full training pass on the given model."""

    def execute(
            self,
            model: Model,
            train_loader: DataLoader,
            validation_loader: DataLoader,
            n_epochs: int,
            device: str
    ) -> nn.Module:
        """Train ``model`` for ``n_epochs`` and return it.

        Args:
            model: The model to train. A model class is accepted as well and
                is instantiated with its default constructor arguments.
            train_loader: Loader passed to ``train`` as the training data.
            validation_loader: Loader passed to ``train`` as the validation data.
            n_epochs: Number of epochs forwarded to ``train``.
            device: Device the model is moved to before training.

        Returns:
            The trained model instance.

        Raises:
            ValueError: If ``model`` is ``None``.
        """
        if model is None:
            raise ValueError("Model must be provided: run optuna optimizations first")

        # A model *class* may be handed over instead of an instance; calling
        # ``.to(device)`` on the class itself would fail, so build an instance.
        # NOTE(review): assumes the class is constructible without arguments.
        if isinstance(model, type):
            model = model()

        model.to(device)
        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs)
        print_losses(train_loss, val_loss)

        return model

View file

@ -3,60 +3,73 @@ from typing import Callable
import optuna import optuna
import optuna.trial as tr import optuna.trial as tr
import torch import torch
from torch import nn as nn from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from .trainer import Trainer
from ..models.cnn import CNNPredictor
from .train import train from .train import train
from .trainer import Trainer
from ..models import Model, CNNPredictor, Transformer
def create_model(trial: tr.Trial, model: nn.Module):
    """Build a new model whose hyperparameters are sampled from ``trial``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        model: The model class to instantiate. An instance is accepted too,
            in which case its class is used.

    Returns:
        A freshly constructed instance of the requested model class.

    Raises:
        ValueError: If the model is neither a ``CNNPredictor`` nor a
            ``Transformer``.
    """
    # Dispatch on the model class itself. Comparing ``model.__class__`` with
    # ``CNNPredictor.__class__`` compares metaclasses when ``model`` is a
    # class, and never matches when ``model`` is an instance.
    model_cls = model if isinstance(model, type) else type(model)

    if issubclass(model_cls, CNNPredictor):
        return model_cls(
            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
            vocab_size=256,
        )

    if issubclass(model_cls, Transformer):
        nhead = trial.suggest_int("nhead", 2, 8, log=True)
        # Multi-head attention needs d_model to be divisible by nhead, so
        # snap the [64, 512] range to multiples of nhead before stepping.
        d_model_low = -(-64 // nhead) * nhead
        d_model_high = (512 // nhead) * nhead
        d_model = trial.suggest_int("d_model", d_model_low, d_model_high, step=nhead)
        return model_cls(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
        )

    # Fail loudly here instead of returning None, which the caller would
    # immediately dereference.
    raise ValueError(f"No hyperparameter search space defined for model {model_cls.__name__}")
def objective_function(
        trial: tr.Trial,
        training_loader: DataLoader,
        validation_loader: DataLoader,
        model: Model,
        device: str
):
    """Optuna objective: train one sampled candidate and score it.

    A candidate is created from ``model`` via ``create_model`` with
    hyperparameters drawn from ``trial``, moved to ``device`` and trained
    with its own ``loss_function``. No epoch count is passed to ``train``.

    Returns:
        The smallest validation loss reported by ``train``.
    """
    candidate = create_model(trial, model)
    candidate = candidate.to(device)
    losses = train(candidate, training_loader, validation_loader, candidate.loss_function)
    _, validation_losses = losses
    return min(validation_losses)
class OptunaTrainer(Trainer):
    """Trainer that searches model hyperparameters with an Optuna study."""

    def __init__(self, n_trials: int | None = None):
        """Create the trainer.

        Args:
            n_trials: Number of Optuna trials; any falsy value (``None`` or
                ``0``) falls back to 20.
        """
        super().__init__()
        self.n_trials = n_trials if n_trials else 20
        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

    def execute(
            self,
            model: Model,
            train_loader: DataLoader,
            validation_loader: DataLoader,
            n_epochs: int,
            device: str
    ) -> nn.Module:
        """Run the hyperparameter search and return a model built from the best trial.

        Args:
            model: The model class to optimise. An instance is accepted too,
                in which case its class is used.
            train_loader: Loader forwarded to the objective as training data.
            validation_loader: Loader forwarded to the objective as validation data.
            n_epochs: Accepted for interface compatibility; not used here.
            device: Device each candidate is moved to by the objective.

        Returns:
            A new model constructed with the best trial's parameters. It is
            also saved to ``saved_models/<ClassName>.pt``.
        """
        import os

        # Work with the class: ``model.__class__.__name__`` would name the
        # metaclass when a class is passed, and calling an instance with the
        # best parameters would invoke ``forward`` instead of a constructor.
        model_cls = model if isinstance(model, type) else type(model)

        study = optuna.create_study(direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, model_cls, device),
            n_trials=self.n_trials
        )

        best_params = study.best_trial.params
        # NOTE(review): this builds a fresh model from the best parameters;
        # weights trained during the trial are not carried over — confirm
        # whether the caller is expected to train it afterwards.
        best_model = model_cls(
            **best_params
        )

        # Make sure the target directory exists so the save cannot fail after
        # the whole study has already run.
        os.makedirs("saved_models", exist_ok=True)
        torch.save(best_model, f"saved_models/{model_cls.__name__}.pt")

        return best_model

View file

@ -1,7 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable
import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -15,8 +13,7 @@ class Trainer(ABC):
model: nn.Module | None, model: nn.Module | None,
train_loader: DataLoader, train_loader: DataLoader,
validation_loader: DataLoader, validation_loader: DataLoader,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
n_epochs: int, n_epochs: int,
device: str device: str
) -> None: ) -> nn.Module:
pass pass