fix: Runtime errors
This commit is contained in:
parent
abbf749029
commit
28ae8191ad
4 changed files with 20 additions and 12 deletions
3
main.py
3
main.py
|
|
@ -22,6 +22,7 @@ def main():
|
|||
data_root = args.data_root,
|
||||
n_trials = 3 if args.debug else None,
|
||||
size = 2**10 if args.debug else None,
|
||||
method = args.method,
|
||||
model_name=args.model,
|
||||
model_path = args.model_load_path,
|
||||
model_out = args.model_save_path
|
||||
|
|
@ -33,6 +34,8 @@ def main():
|
|||
case _:
|
||||
raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
|
||||
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
13
src/args.py
13
src/args.py
|
|
@ -29,18 +29,15 @@ def parse_arguments():
|
|||
subparsers = parser.add_subparsers(dest="mode", required=True,
|
||||
help="Mode to run in")
|
||||
|
||||
# TODO
|
||||
fetch_parser = subparsers.add_parser("fetch", parents=[dataparser],
|
||||
help="Only fetch the dataset, then exit")
|
||||
|
||||
train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
|
||||
train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
|
||||
train_parser = subparsers.add_parser("train",
|
||||
parents=[dataparser, modelparser],
|
||||
help="Do a full training")
|
||||
train_parser.add_argument("--method",
|
||||
choices=["fetch", "optuna", "full"], required=True,
|
||||
help="Method to use for training")
|
||||
|
||||
# TODO
|
||||
compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
|
||||
|
||||
# TODO
|
||||
decompress_parser = subparsers.add_parser("decompress", parents=[modelparser, fileparser])
|
||||
|
||||
return parser.parse_args(), parser.print_help
|
||||
|
|
|
|||
|
|
@ -2,8 +2,12 @@ import torch
|
|||
|
||||
|
||||
def compress(
|
||||
input_file: str | None = None
|
||||
device,
|
||||
model_path: str,
|
||||
output_file: str,
|
||||
input_file: str | None = None
|
||||
):
|
||||
# Get input to compress
|
||||
if input_file:
|
||||
with open(input_file, "rb") as file:
|
||||
byte_data = file.read()
|
||||
|
|
@ -15,6 +19,9 @@ def compress(
|
|||
tensor = torch.tensor(list(byte_data), dtype=torch.long)
|
||||
print(tensor)
|
||||
|
||||
# Get model
|
||||
model = torch.load(model_path, weights_only=False)
|
||||
|
||||
# TODO Feed to model for compression, store result
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ def train(
|
|||
data_root: str,
|
||||
n_trials: int | None = None,
|
||||
size: int | None = None,
|
||||
mode: str = "train",
|
||||
method: str = 'optuna',
|
||||
model_name: str | None = None,
|
||||
model_path: str | None = None,
|
||||
|
|
@ -30,7 +29,7 @@ def train(
|
|||
model = model_called[model_name]
|
||||
else:
|
||||
print("Loading model from disk")
|
||||
model = torch.load(model_path)
|
||||
model = torch.load(model_path, weights_only=False)
|
||||
|
||||
dataset_common_args = {
|
||||
'root': data_root,
|
||||
|
|
@ -48,7 +47,7 @@ def train(
|
|||
# TODO Allow to import arbitrary files
|
||||
raise NotImplementedError(f"Importing external datasets is not implemented yet")
|
||||
|
||||
if mode == 'fetch':
|
||||
if method == 'fetch':
|
||||
# TODO More to earlier in chain, because now everything is converted into tensors as well?
|
||||
exit(0)
|
||||
|
||||
|
|
@ -72,3 +71,5 @@ def train(
|
|||
# Make sure path exists
|
||||
Path(f).parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(best_model, f)
|
||||
print(f"Saved model to '{f}'")
|
||||
|
||||
|
|
|
|||
Reference in a new issue