fix: Runtime errors
This commit is contained in:
parent
abbf749029
commit
28ae8191ad
4 changed files with 20 additions and 12 deletions
|
|
@ -15,7 +15,6 @@ def train(
|
|||
data_root: str,
|
||||
n_trials: int | None = None,
|
||||
size: int | None = None,
|
||||
mode: str = "train",
|
||||
method: str = 'optuna',
|
||||
model_name: str | None = None,
|
||||
model_path: str | None = None,
|
||||
|
|
@ -30,7 +29,7 @@ def train(
|
|||
model = model_called[model_name]
|
||||
else:
|
||||
print("Loading model from disk")
|
||||
model = torch.load(model_path)
|
||||
model = torch.load(model_path, weights_only=False)
|
||||
|
||||
dataset_common_args = {
|
||||
'root': data_root,
|
||||
|
|
@ -48,7 +47,7 @@ def train(
|
|||
# TODO Allow to import arbitrary files
|
||||
raise NotImplementedError(f"Importing external datasets is not implemented yet")
|
||||
|
||||
if mode == 'fetch':
|
||||
if method == 'fetch':
|
||||
# TODO More to earlier in chain, because now everything is converted into tensors as well?
|
||||
exit(0)
|
||||
|
||||
|
|
@ -72,3 +71,5 @@ def train(
|
|||
# Make sure path exists
|
||||
Path(f).parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(best_model, f)
|
||||
print(f"Saved model to '{f}'")
|
||||
|
||||
|
|
|
|||
Reference in a new issue