fix: No hardcoding context len
This commit is contained in:
parent
4c8d603092
commit
fc75ab51b0
2 changed files with 6 additions and 5 deletions
3
main.py
3
main.py
|
|
@ -40,7 +40,8 @@ def main():
|
|||
compress(device=device,
|
||||
model_path=args.model_load_path,
|
||||
input_file=args.input_file,
|
||||
output_file=args.output_file
|
||||
output_file=args.output_file,
|
||||
context_length=args.context
|
||||
)
|
||||
|
||||
case _:
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
import contextlib
|
||||
from collections import deque
|
||||
from decimal import Decimal
|
||||
|
||||
import torch
|
||||
from pyae import ArithmeticEncoding
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.utils import reference_ae
|
||||
|
||||
|
||||
def compress(
|
||||
device,
|
||||
model_path: str,
|
||||
context_length: int = 128,
|
||||
input_file: str | None = None,
|
||||
output_file: str | None = None
|
||||
):
|
||||
# NOTE Hardcoded context length
|
||||
context_length = 128
|
||||
|
||||
# Get input to compress
|
||||
print("Reading input")
|
||||
if input_file:
|
||||
|
|
|
|||
Reference in a new issue