fix: No hardcoding context len

This commit is contained in:
Tibo De Peuter 2025-12-11 23:23:55 +01:00
parent 4c8d603092
commit fc75ab51b0
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
2 changed files with 6 additions and 5 deletions

View file

@ -40,7 +40,8 @@ def main():
compress(device=device,
model_path=args.model_load_path,
input_file=args.input_file,
output_file=args.output_file
output_file=args.output_file,
context_length=args.context
)
case _:

View file

@ -1,20 +1,20 @@
import contextlib
from collections import deque
from decimal import Decimal
import torch
from pyae import ArithmeticEncoding
from tqdm import tqdm
from src.utils import reference_ae
def compress(
device,
model_path: str,
context_length: int = 128,
input_file: str | None = None,
output_file: str | None = None
):
# NOTE Hardcoded context length
context_length = 128
# Get input to compress
print("Reading input")
if input_file: