feat: Add model choice
This commit is contained in:
parent
bb241154d9
commit
ef50d6321e
10 changed files with 102 additions and 54 deletions
4
main.py
4
main.py
|
|
@ -22,7 +22,9 @@ def main():
|
||||||
data_root = args.data_root,
|
data_root = args.data_root,
|
||||||
n_trials = 3 if args.debug else None,
|
n_trials = 3 if args.debug else None,
|
||||||
size = 2**10 if args.debug else None,
|
size = 2**10 if args.debug else None,
|
||||||
model_path = args.model_path
|
model_name=args.model,
|
||||||
|
model_path = args.model_load_path,
|
||||||
|
model_out = args.model_save_path
|
||||||
)
|
)
|
||||||
|
|
||||||
case 'compress':
|
case 'compress':
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,12 @@ def parse_arguments():
|
||||||
dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
|
dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
|
||||||
|
|
||||||
modelparser = ArgumentParser(add_help=False)
|
modelparser = ArgumentParser(add_help=False)
|
||||||
modelparser.add_argument("--model-path", type=str, required=False,
|
modelparser.add_argument("--model", "-m", type=str, required=False,
|
||||||
help="Path to the model to load/save")
|
help="Which model to use")
|
||||||
|
modelparser.add_argument("--model-load-path", type=str, required=False,
|
||||||
|
help="Filepath to the model to load")
|
||||||
|
modelparser.add_argument("--model-save-path", type=str, required=True,
|
||||||
|
help="Filepath to the model to save")
|
||||||
|
|
||||||
fileparser = ArgumentParser(add_help=False)
|
fileparser = ArgumentParser(add_help=False)
|
||||||
fileparser.add_argument("--input-file", "-i", required=False, type=str)
|
fileparser.add_argument("--input-file", "-i", required=False, type=str)
|
||||||
|
|
|
||||||
14
src/models/Model.py
Normal file
14
src/models/Model.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module, ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, loss_function = None):
|
||||||
|
super().__init__()
|
||||||
|
self._loss_function = loss_function
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loss_function(self):
|
||||||
|
return self._loss_function
|
||||||
|
|
@ -1,2 +1,9 @@
|
||||||
|
from .Model import Model
|
||||||
from .cnn import CNNPredictor
|
from .cnn import CNNPredictor
|
||||||
from .transformer import Transformer
|
from .transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
model_called: dict[str, type[Model]] = {
|
||||||
|
'cnn': CNNPredictor,
|
||||||
|
'transformer': Transformer
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,16 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class CNNPredictor(nn.Module):
|
from src.models import Model
|
||||||
|
|
||||||
|
|
||||||
|
class CNNPredictor(Model):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size=256,
|
vocab_size=256,
|
||||||
embed_dim=64,
|
embed_dim=64,
|
||||||
hidden_dim=128,
|
hidden_dim=128,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(nn.CrossEntropyLoss())
|
||||||
|
|
||||||
# 1. Embedding: maps bytes (0–255) → vectors
|
# 1. Embedding: maps bytes (0–255) → vectors
|
||||||
self.embed = nn.Embedding(vocab_size, embed_dim)
|
self.embed = nn.Embedding(vocab_size, embed_dim)
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ class Transformer(nn.Transformer):
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None
|
dtype=None
|
||||||
)
|
)
|
||||||
|
self.loss_function = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
26
src/train.py
26
src/train.py
|
|
@ -1,7 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from src.dataset_loaders import dataset_called
|
from src.dataset_loaders import dataset_called
|
||||||
|
from src.models import model_called
|
||||||
from src.trainers import OptunaTrainer, Trainer, FullTrainer
|
from src.trainers import OptunaTrainer, Trainer, FullTrainer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -13,10 +15,21 @@ def train(
|
||||||
size: int | None = None,
|
size: int | None = None,
|
||||||
mode: str = "train",
|
mode: str = "train",
|
||||||
method: str = 'optuna',
|
method: str = 'optuna',
|
||||||
|
model_name: str | None = None,
|
||||||
model_path: str | None = None,
|
model_path: str | None = None,
|
||||||
|
model_out: str | None = None
|
||||||
):
|
):
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
|
|
||||||
|
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
|
||||||
|
|
||||||
|
if model_name:
|
||||||
|
print("Creating model")
|
||||||
|
model = model_called[model_name]
|
||||||
|
else:
|
||||||
|
print("Loading model from disk")
|
||||||
|
model = torch.load(model_path)
|
||||||
|
|
||||||
dataset_common_args = {
|
dataset_common_args = {
|
||||||
'root': data_root,
|
'root': data_root,
|
||||||
'transform': lambda x: x.to(device),
|
'transform': lambda x: x.to(device),
|
||||||
|
|
@ -41,21 +54,16 @@ def train(
|
||||||
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
|
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
|
||||||
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
|
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
|
||||||
|
|
||||||
loss_fn = torch.nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
model = None
|
|
||||||
if model_path is not None:
|
|
||||||
print("Loading the models...")
|
|
||||||
model = torch.load(model_path)
|
|
||||||
|
|
||||||
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
|
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
|
||||||
|
|
||||||
print("Training")
|
print("Training")
|
||||||
trainer.execute(
|
best_model = trainer.execute(
|
||||||
model=model,
|
model=model,
|
||||||
train_loader=training_loader,
|
train_loader=training_loader,
|
||||||
validation_loader=validation_loader,
|
validation_loader=validation_loader,
|
||||||
loss_fn=loss_fn,
|
|
||||||
n_epochs=200,
|
n_epochs=200,
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print("Saving model...")
|
||||||
|
torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt")
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,26 @@
|
||||||
from typing import Callable
|
from torch import nn
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn as nn
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .trainer import Trainer
|
|
||||||
from .train import train
|
from .train import train
|
||||||
|
from .trainer import Trainer
|
||||||
|
from ..models import Model
|
||||||
from ..utils import print_losses
|
from ..utils import print_losses
|
||||||
|
|
||||||
|
|
||||||
class FullTrainer(Trainer):
|
class FullTrainer(Trainer):
|
||||||
def execute(
|
def execute(
|
||||||
self,
|
self,
|
||||||
model: nn.Module | None,
|
model: Model,
|
||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
|
||||||
n_epochs: int,
|
n_epochs: int,
|
||||||
device: str
|
device: str
|
||||||
) -> None:
|
) -> nn.Module:
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError("Model must be provided: run optuna optimizations first")
|
raise ValueError("Model must be provided: run optuna optimizations first")
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
|
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs)
|
||||||
print_losses(train_loss, val_loss)
|
print_losses(train_loss, val_loss)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
|
||||||
|
|
@ -3,60 +3,73 @@ from typing import Callable
|
||||||
import optuna
|
import optuna
|
||||||
import optuna.trial as tr
|
import optuna.trial as tr
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .trainer import Trainer
|
|
||||||
from ..models.cnn import CNNPredictor
|
|
||||||
from .train import train
|
from .train import train
|
||||||
|
from .trainer import Trainer
|
||||||
|
from ..models import Model, CNNPredictor, Transformer
|
||||||
|
|
||||||
|
|
||||||
def create_model(trial: tr.Trial, vocab_size: int = 256):
|
def create_model(trial: tr.Trial, model: nn.Module):
|
||||||
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
|
match model.__class__:
|
||||||
embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)
|
case CNNPredictor.__class__:
|
||||||
|
return model(
|
||||||
return CNNPredictor(
|
hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
|
||||||
vocab_size=vocab_size,
|
embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
|
||||||
hidden_dim=hidden_dim,
|
vocab_size=256,
|
||||||
embed_dim=embedding_dim,
|
)
|
||||||
)
|
case Transformer.__class__:
|
||||||
|
nhead = trial.suggest_int("nhead", 2, 8, log=True)
|
||||||
|
d_model = trial.suggest_int("d_model", 64, 512, step=nhead)
|
||||||
|
return model(
|
||||||
|
d_model=d_model,
|
||||||
|
nhead=nhead,
|
||||||
|
num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
|
||||||
|
num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
|
||||||
|
dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
|
||||||
|
dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
|
||||||
|
activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
|
||||||
|
layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def objective_function(
|
def objective_function(
|
||||||
trial: tr.Trial,
|
trial: tr.Trial,
|
||||||
training_loader: DataLoader,
|
training_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
model: Model,
|
||||||
device: str
|
device: str
|
||||||
):
|
):
|
||||||
model = create_model(trial).to(device)
|
model = create_model(trial, model).to(device)
|
||||||
_, validation_loss = train(model, training_loader, validation_loader, loss_fn)
|
_, validation_loss = train(model, training_loader, validation_loader, model.loss_function)
|
||||||
return min(validation_loss)
|
return min(validation_loss)
|
||||||
|
|
||||||
|
|
||||||
class OptunaTrainer(Trainer):
|
class OptunaTrainer(Trainer):
|
||||||
def __init__(self, n_trials: int | None = None):
|
def __init__(self, n_trials: int | None = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_trials = n_trials if n_trials is not None else 20
|
self.n_trials = n_trials if n_trials else 20
|
||||||
print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
|
print(f"Creating Optuna trainer(n_trials = {self.n_trials})")
|
||||||
|
|
||||||
def execute(
|
def execute(
|
||||||
self,
|
self,
|
||||||
model: nn.Module | None,
|
model: Model,
|
||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
|
||||||
n_epochs: int,
|
n_epochs: int,
|
||||||
device: str
|
device: str
|
||||||
) -> None:
|
) -> nn.Module:
|
||||||
study = optuna.create_study(study_name="CNN network", direction="minimize")
|
study = optuna.create_study(direction="minimize")
|
||||||
study.optimize(
|
study.optimize(
|
||||||
lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
|
lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
|
||||||
n_trials=self.n_trials
|
n_trials=self.n_trials
|
||||||
)
|
)
|
||||||
|
|
||||||
best_params = study.best_trial.params
|
best_params = study.best_trial.params
|
||||||
best_model = CNNPredictor(
|
best_model = model(
|
||||||
**best_params
|
**best_params
|
||||||
)
|
)
|
||||||
torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")
|
|
||||||
|
return best_model
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
@ -15,8 +13,7 @@ class Trainer(ABC):
|
||||||
model: nn.Module | None,
|
model: nn.Module | None,
|
||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
|
||||||
n_epochs: int,
|
n_epochs: int,
|
||||||
device: str
|
device: str
|
||||||
) -> None:
|
) -> nn.Module:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
Reference in a new issue