fix: Check paths

This commit is contained in:
Tibo De Peuter 2025-12-07 21:56:47 +01:00
parent ee4d94e157
commit 37ab43c134
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2

View file

@ -1,3 +1,5 @@
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import DataLoader
@ -61,9 +63,12 @@ def train(
model=model,
train_loader=training_loader,
validation_loader=validation_loader,
n_epochs=200,
n_epochs=n_trials,
device=device
)
print("Saving model...")
torch.save(best_model, model_out or f"saved_models/{model.__class__.__name__}.pt")
f = model_out or f"saved_models/{model.__class__.__name__}.pt"
# Make sure path exists
Path(f).parent.mkdir(parents=True, exist_ok=True)
torch.save(best_model, f)