feat: faster encoding binary

This commit is contained in:
Robin Meersman 2025-12-11 15:22:28 +01:00
parent 5c26a52e16
commit 653b44804a

368
src/utils/custom_ae.py Normal file
View file

@ -0,0 +1,368 @@
from decimal import Decimal
class CustomArithmeticEncoding:
"""
ArithmeticEncoding is a class for building the arithmetic encoding.
"""
def __init__(self, frequency_table, save_stages=False):
"""
frequency_table: Frequency table as a dictionary where key is the symbol and value is the frequency.
save_stages: If True, then the intervals of each stage are saved in a list. Note that setting save_stages=True may cause memory overflow if the message is large
"""
self.save_stages = save_stages
if (save_stages == True):
print("WARNING: Setting save_stages=True may cause memory overflow if the message is large.")
self.probability_table = self.get_probability_table(frequency_table)
def get_probability_table(self, frequency_table):
"""
Calculates the probability table out of the frequency table.
frequency_table: A table of the term frequencies.
Returns the probability table.
"""
total_frequency = sum(list(frequency_table.values()))
probability_table = {}
for key, value in frequency_table.items():
probability_table[key] = value / total_frequency
return probability_table
def get_encoded_value(self, last_stage_probs):
"""
After encoding the entire message, this method returns the single value that represents the entire message.
last_stage_probs: A list of the probabilities in the last stage.
Returns the minimum and maximum probabilites in the last stage in addition to the value encoding the message.
"""
last_stage_probs = list(last_stage_probs.values())
last_stage_values = []
for sublist in last_stage_probs:
for element in sublist:
last_stage_values.append(element)
last_stage_min = min(last_stage_values)
last_stage_max = max(last_stage_values)
encoded_value = (last_stage_min + last_stage_max) / 2
return last_stage_min, last_stage_max, encoded_value
def process_stage(self, probability_table, stage_min, stage_max):
"""
Processing a stage in the encoding/decoding process.
probability_table: The probability table.
stage_min: The minumim probability of the current stage.
stage_max: The maximum probability of the current stage.
Returns the probabilities in the stage.
"""
stage_probs = {}
stage_domain = stage_max - stage_min
for term_idx in range(len(probability_table.items())):
term = list(probability_table.keys())[term_idx]
term_prob = Decimal(probability_table[term])
cum_prob = term_prob * stage_domain + stage_min
stage_probs[term] = [stage_min, cum_prob]
stage_min = cum_prob
return stage_probs
def encode(self, msg, probability_table):
"""
Encodes a message using arithmetic encoding.
msg: The message to be encoded.
probability_table: The probability table.
Returns the encoder, the floating-point value representing the encoded message, and the maximum and minimum values of the interval in which the floating-point value falls.
"""
msg = list(msg)
encoder = []
stage_min = Decimal(0.0)
stage_max = Decimal(1.0)
for msg_term_idx in range(len(msg)):
stage_probs = self.process_stage(probability_table, stage_min, stage_max)
msg_term = msg[msg_term_idx]
stage_min = stage_probs[msg_term][0]
stage_max = stage_probs[msg_term][1]
if self.save_stages:
encoder.append(stage_probs)
last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)
if self.save_stages:
encoder.append(last_stage_probs)
interval_min_value, interval_max_value, encoded_msg = self.get_encoded_value(last_stage_probs)
return encoded_msg, encoder, interval_min_value, interval_max_value
def process_stage_binary(self, float_interval_min, float_interval_max, stage_min_bin, stage_max_bin):
"""
Processing a stage in the encoding/decoding process.
float_interval_min: The minimum floating-point value in the interval in which the floating-point value that encodes the message is located.
float_interval_max: The maximum floating-point value in the interval in which the floating-point value that encodes the message is located.
stage_min_bin: The minimum binary number in the current stage.
stage_max_bin: The maximum binary number in the current stage.
Returns the probabilities of the terms in this stage. There are only 2 terms.
"""
stage_mid_bin = stage_min_bin + "1"
stage_min_bin = stage_min_bin + "0"
stage_probs = {}
stage_probs[0] = [stage_min_bin, stage_mid_bin]
stage_probs[1] = [stage_mid_bin, stage_max_bin]
return stage_probs
def encode_binary(self, float_interval_min, float_interval_max):
"""
Calculates the binary code that represents the floating-point value that encodes the message.
float_interval_min: The minimum floating-point value in the interval in which the floating-point value that encodes the message is located.
float_interval_max: The maximum floating-point value in the interval in which the floating-point value that encodes the message is located.
Returns the binary code representing the encoded message.
"""
binary_encoder = []
binary_code = None
stage_min_bin = "0.0"
stage_max_bin = "1.0"
stage_probs = {}
stage_probs[0] = [stage_min_bin, "0.1"]
stage_probs[1] = ["0.1", stage_max_bin]
while True:
if float_interval_max < bin2float(stage_probs[0][1]):
stage_min_bin = stage_probs[0][0]
stage_max_bin = stage_probs[0][1]
else:
stage_min_bin = stage_probs[1][0]
stage_max_bin = stage_probs[1][1]
if self.save_stages:
binary_encoder.append(stage_probs)
stage_probs = self.process_stage_binary(float_interval_min,
float_interval_max,
stage_min_bin,
stage_max_bin)
# print(stage_probs[0][0], bin2float(stage_probs[0][0]))
# print(stage_probs[0][1], bin2float(stage_probs[0][1]))
if (bin2float(stage_probs[0][0]) >= float_interval_min) and (
bin2float(stage_probs[0][1]) < float_interval_max):
# The binary code is found.
# print(stage_probs[0][0], bin2float(stage_probs[0][0]))
# print(stage_probs[0][1], bin2float(stage_probs[0][1]))
# print("The binary code is : ", stage_probs[0][0])
binary_code = stage_probs[0][0]
break
elif (bin2float(stage_probs[1][0]) >= float_interval_min) and (
bin2float(stage_probs[1][1]) < float_interval_max):
# The binary code is found.
# print(stage_probs[1][0], bin2float(stage_probs[1][0]))
# print(stage_probs[1][1], bin2float(stage_probs[1][1]))
# print("The binary code is : ", stage_probs[1][0])
binary_code = stage_probs[1][0]
break
if self.save_stages:
binary_encoder.append(stage_probs)
return binary_code, binary_encoder
def custom_binary_encoding(self, float_interval_min, float_interval_max):
"""
Find the binary representation of the floating punt number which lies in
[float_interval_min, float_interval_max).
float_interval_min: float
float_interval_max: float
"""
code = []
k = 1
halves = [
[0.0, 1 / 2],
[1 / 2, 1.0]
]
i = 0
while i < 1024:
k += 1
i += 1
if halves[0][0] >= float_interval_min and halves[0][1] < float_interval_max:
break
if halves[1][0] >= float_interval_min and halves[1][1] < float_interval_max:
break
# left interval, insert 0
if float_interval_max < halves[0][1]:
code.append(0)
low = halves[0][0]
high = halves[0][1]
else:
code.append(1)
low = halves[1][0]
high = halves[1][1]
halves[0][0] = low
halves[0][1] = low + 1 / (1 << k)
halves[1][0] = halves[0][1]
halves[1][1] = high
return "0." + ''.join(map(str, code)), k
def decode(self, encoded_msg, msg_length, probability_table):
"""
Decodes a message from a floating-point number.
encoded_msg: The floating-point value that encodes the message.
msg_length: Length of the message.
probability_table: The probability table.
Returns the decoded message.
"""
decoder = []
decoded_msg = []
stage_min = Decimal(0.0)
stage_max = Decimal(1.0)
for idx in range(msg_length):
stage_probs = self.process_stage(probability_table, stage_min, stage_max)
for msg_term, value in stage_probs.items():
if encoded_msg >= value[0] and encoded_msg <= value[1]:
break
decoded_msg.append(msg_term)
stage_min = stage_probs[msg_term][0]
stage_max = stage_probs[msg_term][1]
if self.save_stages:
decoder.append(stage_probs)
if self.save_stages:
last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)
decoder.append(last_stage_probs)
return decoded_msg, decoder
def float2bin(float_num, num_bits=None):
"""
Converts a floating-point number into binary.
float_num: The floating-point number.
num_bits: The number of bits expected in the result. If None, then the number of bits depends on the number.
Returns the binary representation of the number.
"""
float_num = str(float_num)
if float_num.find(".") == -1:
# No decimals in the floating-point number.
integers = float_num
decimals = ""
else:
integers, decimals = float_num.split(".")
decimals = "0." + decimals
decimals = Decimal(decimals)
integers = int(integers)
result = ""
num_used_bits = 0
while True:
mul = decimals * 2
int_part = int(mul)
result = result + str(int_part)
num_used_bits = num_used_bits + 1
decimals = mul - int(mul)
if type(num_bits) is type(None):
if decimals == 0:
break
elif num_used_bits >= num_bits:
break
if type(num_bits) is type(None):
pass
elif len(result) < num_bits:
num_remaining_bits = num_bits - len(result)
result = result + "0" * num_remaining_bits
integers_bin = bin(integers)[2:]
result = str(integers_bin) + "." + str(result)
return result
def bin2float(bin_num):
"""
Converts a binary number to a floating-point number.
bin_num: The binary number as a string.
Returns the floating-point representation.
"""
if bin_num.find(".") == -1:
# No decimals in the binary number.
integers = bin_num
decimals = ""
else:
integers, decimals = bin_num.split(".")
result = Decimal(0.0)
# Working with integers.
for idx, bit in enumerate(integers):
if bit == "0":
continue
mul = 2 ** idx
result = result + Decimal(mul)
# Working with decimals.
for idx, bit in enumerate(decimals):
if bit == "0":
continue
mul = Decimal(1.0) / Decimal((2 ** (idx + 1)))
result = result + mul
return result
if __name__ == "__main__":
coder = CustomArithmeticEncoding({})
low = 0.25
high = 0.5
# slow_code = coder.encode_binary(low, high)
fast_code = coder.custom_binary_encoding(low, high)
# print(slow_code)
print(fast_code)