diff --git a/main.py b/main.py
index db9147b..ee78c21 100644
--- a/main.py
+++ b/main.py
@@ -22,7 +22,9 @@ def main():
                 data_root = args.data_root,
                 n_trials = 3 if args.debug else None,
                 size = 2**10 if args.debug else None,
-                model_path = args.model_path
+                model_name=args.model,
+                model_path = args.model_load_path,
+                model_out = args.model_save_path
             )
         case 'compress':
diff --git a/src/args.py b/src/args.py
index 55a72c2..2319352 100644
--- a/src/args.py
+++ b/src/args.py
@@ -15,8 +15,12 @@ def parse_arguments():
     dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)

     modelparser = ArgumentParser(add_help=False)
-    modelparser.add_argument("--model-path", type=str, required=False,
-                             help="Path to the model to load/save")
+    modelparser.add_argument("--model", "-m", type=str, required=False,
+                             help="Which model to use")
+    modelparser.add_argument("--model-load-path", type=str, required=False,
+                             help="Filepath to the model to load")
+    modelparser.add_argument("--model-save-path", type=str, required=False,
+                             help="Filepath to the model to save")

     fileparser = ArgumentParser(add_help=False)
     fileparser.add_argument("--input-file", "-i", required=False, type=str)
diff --git a/src/models/Model.py b/src/models/Model.py
new file mode 100644
index 0000000..af8d5d3
--- /dev/null
+++ b/src/models/Model.py
@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+
+from torch import nn
+
+
+class Model(nn.Module, ABC):
+    @abstractmethod
+    def __init__(self, loss_function = None):
+        super().__init__()
+        self._loss_function = loss_function
+
+    @property
+    def loss_function(self):
+        return self._loss_function
diff --git a/src/models/__init__.py b/src/models/__init__.py
index 42e6d4c..e7e29b4 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -1,2 +1,9 @@
+from .Model import Model
 from .cnn import CNNPredictor
-from .transformer import Transformer
\ No newline at end of file
+from .transformer import Transformer
+
+
+model_called: dict[str, 
type[Model]] = { + 'cnn': CNNPredictor, + 'transformer': Transformer +} diff --git a/src/models/cnn/cnn.py b/src/models/cnn/cnn.py index 05768d7..22e8843 100644 --- a/src/models/cnn/cnn.py +++ b/src/models/cnn/cnn.py @@ -1,14 +1,16 @@ -import torch import torch.nn as nn -class CNNPredictor(nn.Module): +from src.models import Model + + +class CNNPredictor(Model): def __init__( self, vocab_size=256, embed_dim=64, hidden_dim=128, ): - super().__init__() + super().__init__(nn.CrossEntropyLoss()) # 1. Embedding: maps bytes (0–255) → vectors self.embed = nn.Embedding(vocab_size, embed_dim) diff --git a/src/models/transformer/transformer.py b/src/models/transformer/transformer.py index 63032eb..774ae7c 100644 --- a/src/models/transformer/transformer.py +++ b/src/models/transformer/transformer.py @@ -30,6 +30,7 @@ class Transformer(nn.Transformer): device=None, dtype=None ) + self.loss_function = nn.CrossEntropyLoss() def forward( self, diff --git a/src/train.py b/src/train.py index 1faffb9..7e51617 100644 --- a/src/train.py +++ b/src/train.py @@ -1,7 +1,9 @@ import torch +from torch import nn from torch.utils.data import DataLoader from src.dataset_loaders import dataset_called +from src.models import model_called from src.trainers import OptunaTrainer, Trainer, FullTrainer @@ -13,10 +15,21 @@ def train( size: int | None = None, mode: str = "train", method: str = 'optuna', + model_name: str | None = None, model_path: str | None = None, + model_out: str | None = None ): batch_size = 2 + assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided" + + if model_name: + print("Creating model") + model = model_called[model_name] + else: + print("Loading model from disk") + model = torch.load(model_path) + dataset_common_args = { 'root': data_root, 'transform': lambda x: x.to(device), @@ -41,21 +54,16 @@ def train( training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True) validation_loader = 
DataLoader(validate_set, batch_size=batch_size, shuffle=False) - loss_fn = torch.nn.CrossEntropyLoss() - - model = None - if model_path is not None: - print("Loading the models...") - model = torch.load(model_path) - trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer() print("Training") - trainer.execute( + best_model = trainer.execute( model=model, train_loader=training_loader, validation_loader=validation_loader, - loss_fn=loss_fn, n_epochs=200, device=device ) + + print("Saving model...") + torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt") diff --git a/src/trainers/FullTrainer.py b/src/trainers/FullTrainer.py index 7f7882a..cb0dc2c 100644 --- a/src/trainers/FullTrainer.py +++ b/src/trainers/FullTrainer.py @@ -1,26 +1,26 @@ -from typing import Callable - -import torch -from torch import nn as nn +from torch import nn from torch.utils.data import DataLoader -from .trainer import Trainer from .train import train +from .trainer import Trainer +from ..models import Model from ..utils import print_losses + class FullTrainer(Trainer): def execute( self, - model: nn.Module | None, + model: Model, train_loader: DataLoader, validation_loader: DataLoader, - loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], n_epochs: int, device: str - ) -> None: + ) -> nn.Module: if model is None: raise ValueError("Model must be provided: run optuna optimizations first") model.to(device) - train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs) - print_losses(train_loss, val_loss) \ No newline at end of file + train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs) + print_losses(train_loss, val_loss) + + return model diff --git a/src/trainers/OptunaTrainer.py b/src/trainers/OptunaTrainer.py index eb896fc..e9a7417 100644 --- a/src/trainers/OptunaTrainer.py +++ b/src/trainers/OptunaTrainer.py @@ -3,60 +3,73 @@ from typing import 
Callable

 import optuna
 import optuna.trial as tr
 import torch
-from torch import nn as nn
+from torch import nn
 from torch.utils.data import DataLoader

-from .trainer import Trainer
-from ..models.cnn import CNNPredictor
 from .train import train
+from .trainer import Trainer
+from ..models import Model, CNNPredictor, Transformer


-def create_model(trial: tr.Trial, vocab_size: int = 256):
-    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
-    embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
-
-    return CNNPredictor(
-        vocab_size=vocab_size,
-        hidden_dim=hidden_dim,
-        embed_dim=embedding_dim,
-    )
+def create_model(trial: tr.Trial, model: type[Model]):
+    match model:
+        case _ if model is CNNPredictor:
+            return model(
+                hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
+                embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
+                vocab_size=256,
+            )
+        case _ if model is Transformer:
+            nhead = trial.suggest_int("nhead", 2, 8, log=True)
+            d_model = trial.suggest_int("d_model", 64, 512, step=nhead)
+            return model(
+                d_model=d_model,
+                nhead=nhead,
+                num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
+                num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
+                dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
+                dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
+                activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
+                layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
+            )
+    raise ValueError(f"Unsupported model class: {model!r}")


 def objective_function(
     trial: tr.Trial,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    model: type[Model],
     device: str
 ):
-    model = create_model(trial).to(device)
-    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
+    model = create_model(trial, model).to(device)
+    _, validation_loss = train(model, training_loader, 
validation_loader, model.loss_function) return min(validation_loss) class OptunaTrainer(Trainer): def __init__(self, n_trials: int | None = None): super().__init__() - self.n_trials = n_trials if n_trials is not None else 20 + self.n_trials = n_trials if n_trials else 20 print(f"Creating Optuna trainer(n_trials = {self.n_trials})") def execute( self, - model: nn.Module | None, + model: Model, train_loader: DataLoader, validation_loader: DataLoader, - loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], n_epochs: int, device: str - ) -> None: - study = optuna.create_study(study_name="CNN network", direction="minimize") + ) -> nn.Module: + study = optuna.create_study(direction="minimize") study.optimize( - lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device), + lambda trial: objective_function(trial, train_loader, validation_loader, model, device), n_trials=self.n_trials ) best_params = study.best_trial.params - best_model = CNNPredictor( + best_model = model( **best_params ) - torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt") + + return best_model diff --git a/src/trainers/trainer.py b/src/trainers/trainer.py index 8543589..228f924 100644 --- a/src/trainers/trainer.py +++ b/src/trainers/trainer.py @@ -1,7 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable -import torch import torch.nn as nn from torch.utils.data import DataLoader @@ -15,8 +13,7 @@ class Trainer(ABC): model: nn.Module | None, train_loader: DataLoader, validation_loader: DataLoader, - loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], n_epochs: int, device: str - ) -> None: - pass \ No newline at end of file + ) -> nn.Module: + pass