From eec3b2b1e624c7e2f9b8e756ee99f3b1f2a62458 Mon Sep 17 00:00:00 2001 From: Robin Meersman Date: Thu, 11 Dec 2025 22:45:28 +0100 Subject: [PATCH] feat: decompression --- src/process.py | 61 +++++++++++++++++++++++++++++++++++------- src/utils/custom_ae.py | 1 + 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/src/process.py b/src/process.py index d77f30d..a681960 100644 --- a/src/process.py +++ b/src/process.py @@ -65,18 +65,38 @@ def compress( binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) # Pack - bits = binary_code.split(".", maxsplit=1)[1] - val = int(bits, 2) if len(bits) else 0 - out_bytes = val.to_bytes((len(bits) + 7) // 8, "big") + val = int(binary_code, 2) if len(binary_code) else 0 + out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big") if output_file: print(f"Writing to {output_file}") - with open(output_file, "wb") as file: - file.write(out_bytes) + with open(output_file, "w") as file: + file.write(f"{len(byte_data)}\n") + file.write(binary_code) # todo: temporary, decoding depends on binary string else: print(out_bytes) +def bits_to_number(bits: str) -> float: + n = 0 + for i, bit in enumerate(bits, start=1): + n += int(bit) / (1 << i) + return n + + +def make_cumulative(probs): + cumulative = [] + + total = 0 + + for prob in probs: + low = total + high = total + prob + cumulative.append((low, high)) + total = high + return cumulative + + def decompress( device, model_path: str, @@ -84,10 +104,10 @@ def decompress( output_file: str | None = None ): context_length = 128 - output = bytearray() print("Reading in the data") - with open(input_file, "rb") as f: + with open(input_file, "r") as f: + length = int(f.readline()) bytes_data = f.read() if len(bytes_data) == 0: @@ -100,8 +120,29 @@ def decompress( model.eval() print("Decompressing") - ae = ArithmeticEncoding(frequency_table={0: 1}) - stage_min, stage_max = Decimal(0), Decimal(1) - context = deque([0] * context_length, maxlen=context_length) + output = bytearray() + x = bits_to_number(bytes_data) + + for _ in range(length): + probs = model(context) + cumulative = make_cumulative(probs) + + for symbol, (low, high) in enumerate(cumulative): + if low <= x < high: + break + + output.append(symbol) + context.append(chr(symbol)) + + interval_low, interval_high = cumulative[symbol] + interval_width = interval_high - interval_low + x = (x - interval_low) / interval_width + + if output_file is not None: + with open(output_file, "wb") as f: + f.write(output) + return + + print(output.decode('utf-8', errors='replace')) diff --git a/src/utils/custom_ae.py b/src/utils/custom_ae.py index 006e53b..86f3548 100644 --- a/src/utils/custom_ae.py +++ b/src/utils/custom_ae.py @@ -219,6 +219,7 @@ class CustomArithmeticEncoding: return ''.join(map(str, code)) + def decode(self, encoded_msg, msg_length, probability_table): """ Decodes a message from a floating-point number.