fix: Runtime errors
parent abbf749029
commit 28ae8191ad

4 changed files with 20 additions and 12 deletions
main.py (3 changes)

@@ -22,6 +22,7 @@ def main():
             data_root = args.data_root,
             n_trials = 3 if args.debug else None,
             size = 2**10 if args.debug else None,
+            method = args.method,
             model_name=args.model,
             model_path = args.model_load_path,
             model_out = args.model_save_path

@@ -32,6 +33,8 @@ def main():
         case _:
             raise NotImplementedError(f"Mode {args.mode} is not implemented yet")

+    print("Done")
+

 if __name__ == "__main__":
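From the context lines, main() dispatches on args.mode with a match statement and forwards the parsed CLI arguments to the training entry point; the new method = args.method keyword is what the training code now branches on. A runnable sketch of that dispatch pattern under those assumptions (the "train" arm and the SimpleNamespace stand-in for the parsed arguments are illustrative, not part of the diff):

from types import SimpleNamespace

# Stand-in for the parsed CLI arguments normally returned by parse_arguments()
args = SimpleNamespace(mode="train", method="optuna", debug=True)

match args.mode:
    case "train":
        # train(..., method=args.method, ...) would be called here;
        # only the fallback arm below is visible in the diff.
        print(f"dispatch to train(method={args.method!r})")
    case _:
        raise NotImplementedError(f"Mode {args.mode} is not implemented yet")

print("Done")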
src/args.py (13 changes)

@@ -29,18 +29,15 @@ def parse_arguments():
     subparsers = parser.add_subparsers(dest="mode", required=True,
                                        help="Mode to run in")

-    # TODO
-    fetch_parser = subparsers.add_parser("fetch", parents=[dataparser],
-                                         help="Only fetch the dataset, then exit")
-
-    train_parser = subparsers.add_parser("train", parents=[dataparser, modelparser])
-    train_parser.add_argument("--method", choices=["optuna", "full"], required=True,
+    train_parser = subparsers.add_parser("train",
+                                         parents=[dataparser, modelparser],
+                                         help="Do a full training")
+    train_parser.add_argument("--method",
+                              choices=["fetch", "optuna", "full"], required=True,
                               help="Method to use for training")

-    # TODO
     compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])

-    # TODO
     decompress_parser = subparsers.add_parser("decompress", parents=[modelparser, fileparser])

     return parser.parse_args(), parser.print_help
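The separate fetch subcommand is folded into train as an extra --method choice, so a fetch-only run now goes through the same parent parsers as a real training run. A self-contained sketch of how the parents=[...] pattern composes here, assuming dataparser and modelparser are ArgumentParser(add_help=False) parents defined earlier in parse_arguments() (their flag names below are illustrative, not shown in the diff):

import argparse

# Assumed shape of the shared parent parsers
dataparser = argparse.ArgumentParser(add_help=False)
dataparser.add_argument("--data-root")

modelparser = argparse.ArgumentParser(add_help=False)
modelparser.add_argument("--model")
modelparser.add_argument("--model-load-path")
modelparser.add_argument("--model-save-path")

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="mode", required=True,
                                   help="Mode to run in")
train_parser = subparsers.add_parser("train",
                                     parents=[dataparser, modelparser],
                                     help="Do a full training")
train_parser.add_argument("--method",
                          choices=["fetch", "optuna", "full"], required=True,
                          help="Method to use for training")

# The old `fetch` subcommand becomes `train --method fetch`
args = parser.parse_args(["train", "--method", "fetch"])
print(args.mode, args.method)   # -> train fetch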
@@ -2,8 +2,12 @@ import torch


 def compress(
+        device,
+        model_path: str,
+        output_file: str,
         input_file: str | None = None
 ):
+    # Get input to compress
     if input_file:
         with open(input_file, "rb") as file:
             byte_data = file.read()

@@ -15,6 +19,9 @@ def compress(
     tensor = torch.tensor(list(byte_data), dtype=torch.long)
     print(tensor)

+    # Get model
+    model = torch.load(model_path, weights_only=False)
+
     # TODO Feed to model for compression, store result
     return
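The weights_only=False argument matches the torch.save(best_model, f) call in the training code, which pickles the whole model object rather than a state_dict. PyTorch 2.6 and later default torch.load to weights_only=True, which refuses to unpickle arbitrary objects and raises at load time; that is presumably the runtime error the commit title refers to. A minimal, self-contained sketch of the round trip under that assumption (the nn.Linear and the "model.pt" path are placeholders):

import torch
import torch.nn as nn

model = nn.Linear(256, 8)          # stand-in for the trained model
torch.save(model, "model.pt")      # pickles the full module object

# Default load (weights_only=True on PyTorch >= 2.6) fails for full objects:
# torch.load("model.pt")           # -> UnpicklingError ("Weights only load failed")

# Explicitly opting out restores the old behaviour (only for trusted files):
model = torch.load("model.pt", weights_only=False)
print(type(model))                 # <class 'torch.nn.modules.linear.Linear'>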
@@ -15,7 +15,6 @@ def train(
         data_root: str,
         n_trials: int | None = None,
         size: int | None = None,
-        mode: str = "train",
         method: str = 'optuna',
         model_name: str | None = None,
         model_path: str | None = None,

@@ -30,7 +29,7 @@ def train(
         model = model_called[model_name]
     else:
         print("Loading model from disk")
-        model = torch.load(model_path)
+        model = torch.load(model_path, weights_only=False)

     dataset_common_args = {
         'root': data_root,

@@ -48,7 +47,7 @@ def train(
         # TODO Allow to import arbitrary files
         raise NotImplementedError(f"Importing external datasets is not implemented yet")

-    if mode == 'fetch':
+    if method == 'fetch':
         # TODO More to earlier in chain, because now everything is converted into tensors as well?
         exit(0)

@@ -72,3 +71,5 @@ def train(
     # Make sure path exists
     Path(f).parent.mkdir(parents=True, exist_ok=True)
     torch.save(best_model, f)
+
+    print(f"Saved model to '{f}'")
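With the fetch subcommand gone, train() drops its mode parameter and the early-exit check switches to the method argument (which now also accepts "fetch", per src/args.py above); removing the parameter without updating the check would have turned if mode == 'fetch': into a NameError at runtime, so the two edits belong together. The torch.load call gets the same weights_only=False treatment as in the compress path, and the final hunk just confirms the save. A self-contained sketch of that save-and-confirm pattern (best_model and the output path are placeholders):

from pathlib import Path

import torch
import torch.nn as nn

best_model = nn.Linear(4, 2)          # stand-in for the tuned model
f = "out/model.pt"                    # stand-in for model_out

# Make sure path exists
Path(f).parent.mkdir(parents=True, exist_ok=True)
torch.save(best_model, f)

print(f"Saved model to '{f}'")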