feat: Add model choice
This commit is contained in:
parent
bb241154d9
commit
ef50d6321e
10 changed files with 102 additions and 54 deletions
26
src/train.py
26
src/train.py
|
|
@ -1,7 +1,9 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from src.dataset_loaders import dataset_called
|
||||
from src.models import model_called
|
||||
from src.trainers import OptunaTrainer, Trainer, FullTrainer
|
||||
|
||||
|
||||
|
|
@ -13,10 +15,21 @@ def train(
|
|||
size: int | None = None,
|
||||
mode: str = "train",
|
||||
method: str = 'optuna',
|
||||
model_name: str | None = None,
|
||||
model_path: str | None = None,
|
||||
model_out: str | None = None
|
||||
):
|
||||
batch_size = 2
|
||||
|
||||
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
|
||||
|
||||
if model_name:
|
||||
print("Creating model")
|
||||
model = model_called[model_name]
|
||||
else:
|
||||
print("Loading model from disk")
|
||||
model = torch.load(model_path)
|
||||
|
||||
dataset_common_args = {
|
||||
'root': data_root,
|
||||
'transform': lambda x: x.to(device),
|
||||
|
|
@ -41,21 +54,16 @@ def train(
|
|||
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
|
||||
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
|
||||
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
|
||||
model = None
|
||||
if model_path is not None:
|
||||
print("Loading the models...")
|
||||
model = torch.load(model_path)
|
||||
|
||||
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
|
||||
|
||||
print("Training")
|
||||
trainer.execute(
|
||||
best_model = trainer.execute(
|
||||
model=model,
|
||||
train_loader=training_loader,
|
||||
validation_loader=validation_loader,
|
||||
loss_fn=loss_fn,
|
||||
n_epochs=200,
|
||||
device=device
|
||||
)
|
||||
|
||||
print("Saving model...")
|
||||
torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt")
|
||||
|
|
|
|||
Reference in a new issue