fix: No hardcoding context len

This commit is contained in:
Tibo De Peuter 2025-12-11 23:23:55 +01:00
parent 4c8d603092
commit fc75ab51b0
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
2 changed files with 6 additions and 5 deletions

View file

@ -40,7 +40,8 @@ def main():
compress(device=device, compress(device=device,
model_path=args.model_load_path, model_path=args.model_load_path,
input_file=args.input_file, input_file=args.input_file,
output_file=args.output_file output_file=args.output_file,
context_length=args.context
) )
case _: case _:

View file

@ -1,20 +1,20 @@
import contextlib
from collections import deque from collections import deque
from decimal import Decimal from decimal import Decimal
import torch import torch
from pyae import ArithmeticEncoding
from tqdm import tqdm from tqdm import tqdm
from src.utils import reference_ae
def compress( def compress(
device, device,
model_path: str, model_path: str,
context_length: int = 128,
input_file: str | None = None, input_file: str | None = None,
output_file: str | None = None output_file: str | None = None
): ):
# NOTE Hardcoded context length
context_length = 128
# Get input to compress # Get input to compress
print("Reading input") print("Reading input")
if input_file: if input_file: