fix: No hardcoding context len
This commit is contained in:
parent
4c8d603092
commit
fc75ab51b0
2 changed files with 6 additions and 5 deletions
3
main.py
3
main.py
|
|
@ -40,7 +40,8 @@ def main():
|
||||||
compress(device=device,
|
compress(device=device,
|
||||||
model_path=args.model_load_path,
|
model_path=args.model_load_path,
|
||||||
input_file=args.input_file,
|
input_file=args.input_file,
|
||||||
output_file=args.output_file
|
output_file=args.output_file,
|
||||||
|
context_length=args.context
|
||||||
)
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
|
import contextlib
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pyae import ArithmeticEncoding
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from src.utils import reference_ae
|
||||||
|
|
||||||
|
|
||||||
def compress(
|
def compress(
|
||||||
device,
|
device,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
|
context_length: int = 128,
|
||||||
input_file: str | None = None,
|
input_file: str | None = None,
|
||||||
output_file: str | None = None
|
output_file: str | None = None
|
||||||
):
|
):
|
||||||
# NOTE Hardcoded context length
|
|
||||||
context_length = 128
|
|
||||||
|
|
||||||
# Get input to compress
|
# Get input to compress
|
||||||
print("Reading input")
|
print("Reading input")
|
||||||
if input_file:
|
if input_file:
|
||||||
|
|
|
||||||
Reference in a new issue