fix: Check paths
This commit is contained in:
parent
ee4d94e157
commit
37ab43c134
1 changed files with 7 additions and 2 deletions
|
|
@ -1,3 +1,5 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
@ -61,9 +63,12 @@ def train(
|
||||||
model=model,
|
model=model,
|
||||||
train_loader=training_loader,
|
train_loader=training_loader,
|
||||||
validation_loader=validation_loader,
|
validation_loader=validation_loader,
|
||||||
n_epochs=200,
|
n_epochs=n_trials,
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Saving model...")
|
print("Saving model...")
|
||||||
torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt")
|
f = model_out or f"saved_models/{model.__class__.__name__}.pt"
|
||||||
|
# Make sure path exists
|
||||||
|
Path(f).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save(best_model, f)
|
||||||
|
|
|
||||||
Reference in a new issue