diff --git a/main.py b/main.py index ae4e709..7b1aab2 100644 --- a/main.py +++ b/main.py @@ -40,7 +40,8 @@ def main(): compress(device=device, model_path=args.model_load_path, input_file=args.input_file, - output_file=args.output_file + output_file=args.output_file, + context_length=args.context ) case _: diff --git a/src/process.py b/src/process.py index a681960..dc5c479 100644 --- a/src/process.py +++ b/src/process.py @@ -1,20 +1,20 @@ +import contextlib from collections import deque from decimal import Decimal import torch -from pyae import ArithmeticEncoding from tqdm import tqdm +from src.utils import reference_ae + def compress( device, model_path: str, + context_length: int = 128, input_file: str | None = None, output_file: str | None = None ): - # NOTE Hardcoded context length - context_length = 128 - # Get input to compress print("Reading input") if input_file: