fix: Check paths
This commit is contained in:
parent
ee4d94e157
commit
37ab43c134
1 changed files with 7 additions and 2 deletions
|
|
@ -1,3 +1,5 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
|
@ -61,9 +63,12 @@ def train(
|
|||
model=model,
|
||||
train_loader=training_loader,
|
||||
validation_loader=validation_loader,
|
||||
n_epochs=200,
|
||||
n_epochs=n_trials,
|
||||
device=device
|
||||
)
|
||||
|
||||
print("Saving model...")
|
||||
torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt")
|
||||
f = model_out or f"saved_models/{model.__class__.__name__}.pt"
|
||||
# Make sure path exists
|
||||
Path(f).parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(best_model, f)
|
||||
|
|
|
|||
Reference in a new issue