fix: binary encoding is now much faster

This commit is contained in:
Robin Meersman 2025-12-11 20:36:24 +01:00
parent 653b44804a
commit 77b80914e8
2 changed files with 44 additions and 35 deletions

View file

@ -62,7 +62,7 @@ def compress(
print("Getting encoded value")
interval_min, interval_max, _ = AE.get_encoded_value(stage)
print("Encoding in binary")
binary_code, _ = AE.encode_binary(interval_min, interval_max)
binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
# Pack
bits = binary_code.split(".", maxsplit=1)[1]
@ -77,5 +77,31 @@ def compress(
print(out_bytes)
def decompress():
    """Placeholder for the decompression entry point.

    Raises:
        NotImplementedError: always, because decompression has not been
            written yet.
    """
    # Raise instead of return: a returned exception instance is a truthy
    # object that callers could mistake for a successful result.
    raise NotImplementedError("Decompression is not implemented yet")
def decompress(
        device,
        model_path: str,
        input_file: str,
        output_file: str | None = None
):
    """Prepare the decoder state for an arithmetic-coded input file.

    The payload is read from ``input_file`` as raw bytes; an empty payload
    ends the call early. Otherwise the object stored at ``model_path`` is
    loaded, moved to ``device`` and switched to eval mode, after which the
    initial decoder state is created.

    NOTE(review): the visible code stops after this set-up -- nothing is
    decoded or written, and ``output_file`` is never read. Presumably the
    decoding loop is still to be added; confirm before relying on this.

    Args:
        device: Target passed to ``model.to``.
        model_path: Path handed to ``torch.load``.
        input_file: Path of the compressed file to read.
        output_file: Currently unused.

    Returns:
        None on every visible path.
    """
    print("Reading in the data")
    with open(input_file, "rb") as source:
        payload = source.read()

    # Guard clause: nothing to decode, so skip the model load entirely.
    if not payload:
        print("Input file is empty, nothing has to be done...")
        return

    print("Loading the model")
    # NOTE(review): weights_only=False unpickles the full object and can run
    # arbitrary code from the file -- only load checkpoints you trust.
    model = torch.load(model_path, weights_only=False)
    model.to(device)
    model.eval()

    print("Decompressing")
    # Initial decoder state; none of it is consumed yet in the visible code.
    window = 128
    output = bytearray()
    coder = ArithmeticEncoding(frequency_table={0: 1})
    lower = Decimal(0)
    upper = Decimal(1)
    context = deque([0] * window, maxlen=window)

View file

@ -201,40 +201,23 @@ class CustomArithmeticEncoding:
float_interval_max: float
"""
code = []
k = 1
halves = [
[0.0, 1 / 2],
[1 / 2, 1.0]
]
found = False
next_n = 0.5
n = 0
i = 0
while i < 1024:
k += 1
i += 1
if halves[0][0] >= float_interval_min and halves[0][1] < float_interval_max:
break
if halves[1][0] >= float_interval_min and halves[1][1] < float_interval_max:
break
# left interval, insert 0
if float_interval_max < halves[0][1]:
code.append(0)
low = halves[0][0]
high = halves[0][1]
else:
while not found:
if n + next_n < float_interval_max:
code.append(1)
low = halves[1][0]
high = halves[1][1]
n += next_n
halves[0][0] = low
halves[0][1] = low + 1 / (1 << k)
halves[1][0] = halves[0][1]
halves[1][1] = high
if n >= float_interval_min:
found = True
else:
code.append(0)
return "0." + ''.join(map(str, code)), k
next_n /= 2
return ''.join(map(str, code))
def decode(self, encoded_msg, msg_length, probability_table):
"""
@ -358,8 +341,8 @@ def bin2float(bin_num):
if __name__ == "__main__":
coder = CustomArithmeticEncoding({})
low = 0.25
high = 0.5
low = 0.00324
high = 0.357
# slow_code = coder.encode_binary(low, high)
fast_code = coder.custom_binary_encoding(low, high)