feat: Add model choice

This commit is contained in:
Tibo De Peuter 2025-12-06 21:52:31 +01:00
parent bb241154d9
commit ef50d6321e
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
10 changed files with 102 additions and 54 deletions

View file

@ -1,7 +1,9 @@
import torch
from torch import nn
from torch.utils.data import DataLoader
from src.dataset_loaders import dataset_called
from src.models import model_called
from src.trainers import OptunaTrainer, Trainer, FullTrainer
@ -13,10 +15,21 @@ def train(
size: int | None = None,
mode: str = "train",
method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None,
model_out: str | None = None
):
batch_size = 2
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
if model_name:
print("Creating model")
model = model_called[model_name]
else:
print("Loading model from disk")
model = torch.load(model_path)
dataset_common_args = {
'root': data_root,
'transform': lambda x: x.to(device),
@ -41,21 +54,16 @@ def train(
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
loss_fn = torch.nn.CrossEntropyLoss()
model = None
if model_path is not None:
print("Loading the models...")
model = torch.load(model_path)
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
print("Training")
trainer.execute(
best_model = trainer.execute(
model=model,
train_loader=training_loader,
validation_loader=validation_loader,
loss_fn=loss_fn,
n_epochs=200,
device=device
)
print("Saving model...")
torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt")