fix: Runtime errors

This commit is contained in:
Tibo De Peuter 2025-12-08 11:14:36 +01:00
parent abbf749029
commit 28ae8191ad
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
4 changed files with 20 additions and 12 deletions

View file

@ -15,7 +15,6 @@ def train(
data_root: str,
n_trials: int | None = None,
size: int | None = None,
mode: str = "train",
method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None,
@ -30,7 +29,7 @@ def train(
model = model_called[model_name]
else:
print("Loading model from disk")
model = torch.load(model_path)
model = torch.load(model_path, weights_only=False)
dataset_common_args = {
'root': data_root,
@ -48,7 +47,7 @@ def train(
# TODO Allow to import arbitrary files
raise NotImplementedError(f"Importing external datasets is not implemented yet")
if mode == 'fetch':
if method == 'fetch':
# TODO More to earlier in chain, because now everything is converted into tensors as well?
exit(0)
@ -72,3 +71,5 @@ def train(
# Make sure path exists
Path(f).parent.mkdir(parents=True, exist_ok=True)
torch.save(best_model, f)
print(f"Saved model to '{f}'")