From 2c0e0c227877038a9a89363e71cb715199ce0041 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Thu, 11 Dec 2025 23:24:34 +0100 Subject: [PATCH] WIP: Attempt at switching --- src/process.py | 66 ++++++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/src/process.py b/src/process.py index dc5c479..e59defd 100644 --- a/src/process.py +++ b/src/process.py @@ -36,45 +36,49 @@ def compress( # Init AE print("Initializing AE") - AE = ArithmeticEncoding(frequency_table={0: 1}) # These are dummies because they are not used - stage_min, stage_max = Decimal(0), Decimal(1) - stage = None - # Compress - context = deque([0] * context_length, maxlen=context_length) - for byte in tqdm(tensor.tolist(), desc="Compressing"): - context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device) + with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout: + enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout) - with torch.inference_mode(): - logits = model(context_tensor) - probabilities = torch.softmax(logits[0], dim=-1) - probabilities = probabilities.detach().cpu().numpy() + context = deque([0] * context_length, maxlen=context_length) - eps = 1e-10 - frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} - probability_table = AE.get_probability_table(frequency_table) + stage_min, stage_max = Decimal(0), Decimal(1) + stage = None - stage = AE.process_stage(probability_table, stage_min, stage_max) - stage_min, stage_max = stage[byte] + # Compress + for byte in tqdm(tensor.tolist(), desc="Compressing"): + context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device) - context.append(byte) + with torch.inference_mode(): + logits = model(context_tensor) + probabilities = torch.softmax(logits[0], dim=-1) + probabilities = probabilities.detach().cpu().numpy() - print("Getting encoded value") - interval_min, interval_max, _ = AE.get_encoded_value(stage) - print("Encoding in binary") - binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) + eps = 1e-10 + frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} + probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities)) + probability_table = AE.get_probability_table(frequency_table) - # Pack - val = int(binary_code, 2) if len(binary_code) else 0 - out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big") + enc.write(frequency_table, byte) - if output_file: - print(f"Writing to {output_file}") - with open(output_file, "w") as file: - file.write(f"{len(byte_data)}\n") - file.write(binary_code) # todo: temporary, decoding depends on binary string - else: - print(out_bytes) + context.append(byte) + + # print("Getting encoded value") + # interval_min, interval_max, _ = AE.get_encoded_value(stage) + # print("Encoding in binary") + # binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) + + # Pack + # val = int(binary_code, 2) if len(binary_code) else 0 + # out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big") + + # if output_file: + # print(f"Writing to {output_file}") + # with open(output_file, "w") as file: + # file.write(f"{len(byte_data)}\n") + # file.write(binary_code) # todo: temporary, decoding depends on binary string + # else: + # print(out_bytes) def bits_to_number(bits: str) -> float: