feat (WIP): Compress
parent d0457b6571
commit 5c26a52e16
4 changed files with 70 additions and 8 deletions
@@ -1,13 +1,22 @@
+from collections import deque
+from decimal import Decimal
+
 import torch
+from pyae import ArithmeticEncoding
+from tqdm import tqdm
 
 
 def compress(
-    device,
-    model_path: str,
-    output_file: str,
-    input_file: str | None = None
+    device,
+    model_path: str,
+    input_file: str | None = None,
+    output_file: str | None = None
 ):
     # NOTE Hardcoded context length
     context_length = 128
 
     # Get input to compress
+    print("Reading input")
+    if input_file:
+        with open(input_file, "rb") as file:
+            byte_data = file.read()
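The signature change above makes both input_file and output_file optional, with interactive fallbacks (input() for reading, print() for writing, as the later hunk shows). A call under the new signature might look like this; the module name, paths, and device string are hypothetical, not from the commit:

    import torch
    from compress import compress  # hypothetical module name

    compress(
        device="cuda" if torch.cuda.is_available() else "cpu",
        model_path="model.pt",    # hypothetical checkpoint produced by train()
        input_file="corpus.txt",  # hypothetical file to compress
        output_file="corpus.ae",  # hypothetical destination; printed if omitted
    )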
@@ -16,14 +25,56 @@ def compress(
         text = input()
         byte_data = text.encode('utf-8', errors='replace')
 
     print("Converting to tensor")
     tensor = torch.tensor(list(byte_data), dtype=torch.long)
     print(tensor)
 
     # Get model
     print("Loading model")
     model = torch.load(model_path, weights_only=False)
     model.to(device)
     model.eval()
 
-    # TODO Feed to model for compression, store result
-    return
+    # Init AE
+    print("Initializing AE")
+    AE = ArithmeticEncoding(frequency_table={0: 1})  # These are dummies because they are not used
+    stage_min, stage_max = Decimal(0), Decimal(1)
+    stage = None
+
+    # Compress
+    context = deque([0] * context_length, maxlen=context_length)
+    for byte in tqdm(tensor.tolist(), desc="Compressing"):
+        context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
+
+        with torch.inference_mode():
+            logits = model(context_tensor)
+        probabilities = torch.softmax(logits[0], dim=-1)
+        probabilities = probabilities.detach().cpu().numpy()
+
+        eps = 1e-10
+        frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
+        probability_table = AE.get_probability_table(frequency_table)
+
+        stage = AE.process_stage(probability_table, stage_min, stage_max)
+        stage_min, stage_max = stage[byte]
+
+        context.append(byte)
+
+    print("Getting encoded value")
+    interval_min, interval_max, _ = AE.get_encoded_value(stage)
+    print("Encoding in binary")
+    binary_code, _ = AE.encode_binary(interval_min, interval_max)
+
+    # Pack
+    bits = binary_code.split(".", maxsplit=1)[1]
+    val = int(bits, 2) if len(bits) else 0
+    out_bytes = val.to_bytes((len(bits) + 7) // 8, "big")
+
+    if output_file:
+        print(f"Writing to {output_file}")
+        with open(output_file, "wb") as file:
+            file.write(out_bytes)
+    else:
+        print(out_bytes)
 
 
 def decompress():
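decompress() is still a stub in this commit, but the # Pack step above already pins down the inverse it will need: the written bytes are the integer value of the fractional bits of binary_code, so decoding must turn them back into a number in [0, 1) before re-running the model over the same rolling context. A minimal sketch of that unpacking step (an assumption about the eventual implementation, not committed code):

    from decimal import Decimal

    def unpack_encoded_value(data: bytes, bit_length: int) -> Decimal:
        # Inverse of the pack step: bytes -> integer -> fraction in [0, 1).
        # bit_length must be stored out of band, because to_bytes() pads
        # the bit string up to a whole number of bytes.
        val = int.from_bytes(data, "big")
        return Decimal(val) / (Decimal(2) ** bit_length)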
@@ -19,7 +19,7 @@ def train(
     model_path: str | None = None,
     model_out: str | None = None
 ):
-    batch_size = 2
+    batch_size = 64
 
     assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
 
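This hunk only bumps batch_size from 2 to 64, but the surrounding context shows train()'s loading contract: either model_name or model_path must be provided. Calls that would satisfy the assert might look like this (train's full signature is not shown in the hunk, so these argument combinations and values are assumptions):

    train(model_name="gpt2")                           # hypothetical name: train a freshly created model
    train(model_path="ckpt.pt", model_out="ckpt2.pt")  # hypothetical paths: continue from a checkpoint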