fix: Runtime errors

This commit is contained in:
Tibo De Peuter 2025-12-08 11:14:36 +01:00
parent abbf749029
commit 28ae8191ad
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
4 changed files with 20 additions and 12 deletions

View file

@ -22,6 +22,7 @@ def main():
data_root = args.data_root,
n_trials = 3 if args.debug else None,
size = 2**10 if args.debug else None,
method = args.method,
model_name=args.model,
model_path = args.model_load_path,
model_out = args.model_save_path
@ -32,6 +33,8 @@ def main():
case _:
raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
print("Done")
if __name__ == "__main__":

View file

@ -29,18 +29,15 @@ def parse_arguments():
subparsers = parser.add_subparsers(dest="mode", required=True,
help="Mode to run in")
# TODO
fetch_parser = subparsers.add_parser("fetch", parents=[dataparser],
help="Only fetch the dataset, then exit")
train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
train_parser = subparsers.add_parser("train",
parents=[dataparser, modelparser],
help="Do a full training")
train_parser.add_argument("--method",
choices=["fetch", "optuna", "full"], required=True,
help="Method to use for training")
# TODO
compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
# TODO
decompress_parser = subparsers.add_parser("decompress", parents=[modelparser, fileparser])
return parser.parse_args(), parser.print_help

View file

@ -2,8 +2,12 @@ import torch
def compress(
input_file: str | None = None
device,
model_path: str,
output_file: str,
input_file: str | None = None
):
# Get input to compress
if input_file:
with open(input_file, "rb") as file:
byte_data = file.read()
@ -15,6 +19,9 @@ def compress(
tensor = torch.tensor(list(byte_data), dtype=torch.long)
print(tensor)
# Get model
model = torch.load(model_path, weights_only=False)
# TODO Feed to model for compression, store result
return

View file

@ -15,7 +15,6 @@ def train(
data_root: str,
n_trials: int | None = None,
size: int | None = None,
mode: str = "train",
method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None,
@ -30,7 +29,7 @@ def train(
model = model_called[model_name]
else:
print("Loading model from disk")
model = torch.load(model_path)
model = torch.load(model_path, weights_only=False)
dataset_common_args = {
'root': data_root,
@ -48,7 +47,7 @@ def train(
# TODO Allow to import arbitrary files
raise NotImplementedError(f"Importing external datasets is not implemented yet")
if mode == 'fetch':
if method == 'fetch':
# TODO More to earlier in chain, because now everything is converted into tensors as well?
exit(0)
@ -72,3 +71,5 @@ def train(
# Make sure path exists
Path(f).parent.mkdir(parents=True, exist_ok=True)
torch.save(best_model, f)
print(f"Saved model to '{f}'")