WIP: Attempt at switching

This commit is contained in:
Tibo De Peuter 2025-12-11 23:24:34 +01:00
parent 817c16bde4
commit 2c0e0c2278
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2

View file

@@ -36,45 +36,49 @@ def compress(
# Init AE # Init AE
print("Initializing AE") print("Initializing AE")
AE = ArithmeticEncoding(frequency_table={0: 1}) # These are dummies because they are not used
stage_min, stage_max = Decimal(0), Decimal(1)
stage = None
# Compress with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
context = deque([0] * context_length, maxlen=context_length) enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
for byte in tqdm(tensor.tolist(), desc="Compressing"):
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
with torch.inference_mode(): context = deque([0] * context_length, maxlen=context_length)
logits = model(context_tensor)
probabilities = torch.softmax(logits[0], dim=-1)
probabilities = probabilities.detach().cpu().numpy()
eps = 1e-10 stage_min, stage_max = Decimal(0), Decimal(1)
frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} stage = None
probability_table = AE.get_probability_table(frequency_table)
stage = AE.process_stage(probability_table, stage_min, stage_max) # Compress
stage_min, stage_max = stage[byte] for byte in tqdm(tensor.tolist(), desc="Compressing"):
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
context.append(byte) with torch.inference_mode():
logits = model(context_tensor)
probabilities = torch.softmax(logits[0], dim=-1)
probabilities = probabilities.detach().cpu().numpy()
print("Getting encoded value") eps = 1e-10
interval_min, interval_max, _ = AE.get_encoded_value(stage) frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
print("Encoding in binary") probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities))
binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) probability_table = AE.get_probability_table(frequency_table)
# Pack enc.write(frequency_table, byte)
val = int(binary_code, 2) if len(binary_code) else 0
out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
if output_file: context.append(byte)
print(f"Writing to {output_file}")
with open(output_file, "w") as file: # print("Getting encoded value")
file.write(f"{len(byte_data)}\n") # interval_min, interval_max, _ = AE.get_encoded_value(stage)
file.write(binary_code) # todo: temporary, decoding depends on binary string # print("Encoding in binary")
else: # binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
print(out_bytes)
# Pack
# val = int(binary_code, 2) if len(binary_code) else 0
# out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
# if output_file:
# print(f"Writing to {output_file}")
# with open(output_file, "w") as file:
# file.write(f"{len(byte_data)}\n")
# file.write(binary_code) # todo: temporary, decoding depends on binary string
# else:
# print(out_bytes)
def bits_to_number(bits: str) -> float: def bits_to_number(bits: str) -> float: