diff --git a/src/train.py b/src/train.py index 7e51617..63afa11 100644 --- a/src/train.py +++ b/src/train.py @@ -1,3 +1,5 @@ +from pathlib import Path + import torch from torch import nn from torch.utils.data import DataLoader @@ -61,9 +63,12 @@ def train( model=model, train_loader=training_loader, validation_loader=validation_loader, - n_epochs=200, + n_epochs=n_trials, device=device ) print("Saving model...") - torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt") + f = model_out or f"saved_models/{model.__class__.__name__}.pt" + # Make sure path exists + Path(f).parent.mkdir(parents=True, exist_ok=True) + torch.save(best_model, f)