feat: decompression

This commit is contained in:
Robin Meersman 2025-12-11 22:45:28 +01:00
parent 77b80914e8
commit eec3b2b1e6
2 changed files with 52 additions and 10 deletions

View file

@ -65,18 +65,38 @@ def compress(
binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
# Pack
bits = binary_code.split(".", maxsplit=1)[1]
val = int(bits, 2) if len(bits) else 0
out_bytes = val.to_bytes((len(bits) + 7) // 8, "big")
val = int(binary_code, 2) if len(binary_code) else 0
out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
if output_file:
print(f"Writing to {output_file}")
with open(output_file, "wb") as file:
file.write(out_bytes)
with open(output_file, "w") as file:
file.write(f"{len(byte_data)}\n")
file.write(binary_code) # todo: temporary, decoding depends on binary string
else:
print(out_bytes)
def bits_to_number(bits: str) -> float:
n = 0
for i, bit in enumerate(bits, start=1):
n += int(bit) / (1 << i)
return n
def make_cumulative(probs):
cumulative = []
total = 0
for prob in probs:
low = total
high = total + prob
cumulative.append((low, high))
total = high
return cumulative
def decompress(
device,
model_path: str,
@ -84,10 +104,10 @@ def decompress(
output_file: str | None = None
):
context_length = 128
output = bytearray()
print("Reading in the data")
with open(input_file, "rb") as f:
with open(input_file, "r") as f:
length = int(f.readline())
bytes_data = f.read()
if len(bytes_data) == 0:
@ -100,8 +120,29 @@ def decompress(
model.eval()
print("Decompressing")
ae = ArithmeticEncoding(frequency_table={0: 1})
stage_min, stage_max = Decimal(0), Decimal(1)
context = deque([0] * context_length, maxlen=context_length)
output = bytearray()
x = bits_to_number(bytes_data)
for _ in range(length):
probs = model(context)
cumulative = make_cumulative(probs)
for symbol, (low, high) in enumerate(cumulative):
if low <= x < high:
break
output.append(symbol)
context.append(chr(symbol))
interval_low, interval_high = cumulative[symbol]
interval_width = interval_high - interval_low
x = (x - interval_low) / interval_width
if output_file is not None:
with open(output_file, "wb") as f:
f.write(output)
return
print(output.decode('utf-8', errors='replace'))

View file

@ -219,6 +219,7 @@ class CustomArithmeticEncoding:
return ''.join(map(str, code))
def decode(self, encoded_msg, msg_length, probability_table):
"""
Decodes a message from a floating-point number.