# Reconstructed from the patch that adds src/utils/custom_ae.py
# (subject: "feat: faster encoding binary").
"""Arithmetic encoding on Decimal intervals, plus binary-code helpers.

The module provides:

* ``CustomArithmeticEncoding`` - encodes a message into a sub-interval of
  [0, 1), decodes it again, and derives a binary fraction that lies inside
  such an interval.
* ``float2bin`` / ``bin2float`` - conversions between numbers and binary
  strings such as ``"10.1"``.
"""

from decimal import Decimal, InvalidOperation
from fractions import Fraction


class CustomArithmeticEncoding:
    """
    Arithmetic encoder/decoder driven by a symbol frequency table.
    """

    def __init__(self, frequency_table, save_stages=False):
        """
        frequency_table: Dictionary mapping each symbol to its frequency.
        save_stages: If true, the intervals of every stage are collected and
            returned by the encode/decode methods. This may use a lot of
            memory for long messages, so a warning is printed.
        """
        self.save_stages = save_stages
        if save_stages:
            print("WARNING: Setting save_stages=True may cause memory overflow if the message is large.")

        self.probability_table = self.get_probability_table(frequency_table)

    def get_probability_table(self, frequency_table):
        """
        Calculates the probability table out of the frequency table.

        frequency_table: Dictionary mapping each symbol to its frequency.

        Returns a dictionary mapping each symbol to frequency / total
        frequency. An empty table yields an empty dictionary.
        """
        total_frequency = sum(frequency_table.values())
        return {
            symbol: frequency / total_frequency
            for symbol, frequency in frequency_table.items()
        }

    def get_encoded_value(self, last_stage_probs):
        """
        Picks the single value that represents the whole message.

        last_stage_probs: Dictionary mapping each symbol to its
            [lower, upper] interval in the last stage.

        Returns the smallest bound, the largest bound, and their midpoint
        (the value encoding the message), in that order.
        """
        bounds = [
            bound
            for interval in last_stage_probs.values()
            for bound in interval
        ]
        last_stage_min = min(bounds)
        last_stage_max = max(bounds)
        encoded_value = (last_stage_min + last_stage_max) / 2

        return last_stage_min, last_stage_max, encoded_value

    def process_stage(self, probability_table, stage_min, stage_max):
        """
        Splits the current interval among the symbols.

        probability_table: Dictionary mapping each symbol to its probability.
        stage_min: Lower bound of the current stage (Decimal).
        stage_max: Upper bound of the current stage (Decimal).

        Returns a dictionary mapping each symbol to its [lower, upper]
        sub-interval. Sub-intervals are laid out consecutively in the
        iteration order of probability_table, starting at stage_min.
        """
        stage_probs = {}
        stage_domain = stage_max - stage_min
        lower = stage_min
        for term, probability in probability_table.items():
            upper = Decimal(probability) * stage_domain + lower
            stage_probs[term] = [lower, upper]
            # The next symbol starts where this one ends.
            lower = upper
        return stage_probs

    def encode(self, msg, probability_table):
        """
        Encodes a message using arithmetic encoding.

        msg: The message to be encoded (an iterable of symbols).
        probability_table: Dictionary mapping each symbol to its probability.

        Returns, in this order: the value encoding the message, the list of
        saved stages (empty unless save_stages is set), and the lower and
        upper bounds of the interval the encoded value falls in.
        """
        encoder = []

        stage_min = Decimal(0.0)
        stage_max = Decimal(1.0)

        for msg_term in msg:
            stage_probs = self.process_stage(probability_table, stage_min, stage_max)

            # Narrow the interval to the sub-interval of the current symbol.
            stage_min, stage_max = stage_probs[msg_term]

            if self.save_stages:
                encoder.append(stage_probs)

        last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)

        if self.save_stages:
            encoder.append(last_stage_probs)

        interval_min_value, interval_max_value, encoded_msg = self.get_encoded_value(last_stage_probs)

        return encoded_msg, encoder, interval_min_value, interval_max_value

    def process_stage_binary(self, float_interval_min, float_interval_max, stage_min_bin, stage_max_bin):
        """
        Splits a binary stage into its two halves.

        float_interval_min: Not used by this method; kept for the interface.
        float_interval_max: Not used by this method; kept for the interface.
        stage_min_bin: Binary string of the lower bound of the current stage.
        stage_max_bin: Binary string of the upper bound of the current stage.

        Returns a dictionary with two entries: key 0 maps to the lower half
        [stage_min_bin + "0", stage_min_bin + "1"] and key 1 maps to the
        upper half [stage_min_bin + "1", stage_max_bin].
        """
        stage_mid_bin = stage_min_bin + "1"
        stage_min_bin = stage_min_bin + "0"

        return {
            0: [stage_min_bin, stage_mid_bin],
            1: [stage_mid_bin, stage_max_bin],
        }

    def encode_binary(self, float_interval_min, float_interval_max):
        """
        Searches for a binary code inside [float_interval_min, float_interval_max).

        float_interval_min: Lower bound of the interval holding the encoded value.
        float_interval_max: Upper bound of the interval holding the encoded value.

        Returns the binary code (a string such as "0.010") and the list of
        saved stages (empty unless save_stages is set).

        Raises ValueError if the interval is empty, because no code exists
        for it and the search would never finish.
        """
        if float_interval_min >= float_interval_max:
            raise ValueError("float_interval_min must be smaller than float_interval_max.")

        def fits(lower_bin, upper_bin):
            # A half is accepted when it lies inside the target interval.
            return (bin2float(lower_bin) >= float_interval_min
                    and bin2float(upper_bin) < float_interval_max)

        binary_encoder = []
        stage_probs = {0: ["0.0", "0.1"], 1: ["0.1", "1.0"]}

        while True:
            # "<=" matters: when the upper bound equals the midpoint, every
            # value of the upper half is outside the half-open interval, so
            # only the lower half can still contain a code.
            if float_interval_max <= bin2float(stage_probs[0][1]):
                stage_min_bin, stage_max_bin = stage_probs[0]
            else:
                stage_min_bin, stage_max_bin = stage_probs[1]

            if self.save_stages:
                binary_encoder.append(stage_probs)

            stage_probs = self.process_stage_binary(float_interval_min,
                                                    float_interval_max,
                                                    stage_min_bin,
                                                    stage_max_bin)

            # The lower half is checked first, then the upper half.
            binary_code = next(
                (lower_bin for lower_bin, upper_bin in stage_probs.values()
                 if fits(lower_bin, upper_bin)),
                None,
            )
            if binary_code is not None:
                break

        if self.save_stages:
            binary_encoder.append(stage_probs)

        return binary_code, binary_encoder

    def custom_binary_encoding(self, float_interval_min, float_interval_max):
        """
        Finds the binary representation of a floating-point number which
        lies in [float_interval_min, float_interval_max), using plain float
        arithmetic and at most 1024 bisection steps.

        float_interval_min: Lower bound of the target interval.
        float_interval_max: Upper bound of the target interval.

        Returns a tuple (code, k). code is a string "0." followed by the
        collected bits. k is the loop counter: it starts at 1 and is
        incremented once per executed iteration. If no fitting half is found
        within 1024 iterations, the bits collected so far are returned.

        Raises ValueError if the interval is empty.
        """
        if float_interval_min >= float_interval_max:
            raise ValueError("float_interval_min must be smaller than float_interval_max.")

        code = []
        k = 1
        low, mid, high = 0.0, 1 / 2, 1.0

        for _ in range(1024):
            k += 1

            if low >= float_interval_min and mid < float_interval_max:
                # The lower half fits; the prefix itself denotes `low`.
                break
            if mid >= float_interval_min and high < float_interval_max:
                # The upper half fits; its lower bound is the prefix plus a
                # trailing 1 bit.
                code.append(1)
                break

            # "<=" matters: when the upper bound equals the midpoint, only
            # the lower half can still contain a code.
            if float_interval_max <= mid:
                code.append(0)
                high = mid
            else:
                code.append(1)
                low = mid

            mid = low + 1 / (1 << k)

        return "0." + "".join(map(str, code)), k

    def decode(self, encoded_msg, msg_length, probability_table):
        """
        Decodes a message from the value that encodes it.

        encoded_msg: The value that encodes the message.
        msg_length: Number of symbols to decode.
        probability_table: Dictionary mapping each symbol to its probability.

        Returns the decoded message as a list of symbols and the list of
        saved stages (empty unless save_stages is set).
        """
        decoder = []
        decoded_msg = []

        stage_min = Decimal(0.0)
        stage_max = Decimal(1.0)

        for _ in range(msg_length):
            stage_probs = self.process_stage(probability_table, stage_min, stage_max)

            # The first symbol whose interval contains the value wins. If no
            # interval matches, the loop ends on the last symbol, which is
            # then used.
            for msg_term, (lower, upper) in stage_probs.items():
                if lower <= encoded_msg <= upper:
                    break

            decoded_msg.append(msg_term)

            stage_min, stage_max = stage_probs[msg_term]

            if self.save_stages:
                decoder.append(stage_probs)

        if self.save_stages:
            last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)
            decoder.append(last_stage_probs)

        return decoded_msg, decoder


def float2bin(float_num, num_bits=None):
    """
    Converts a non-negative number into a binary string.

    float_num: The number to convert; it is parsed from its str() form, so
        scientific notation such as 1e-07 is accepted.
    num_bits: The number of fractional bits in the result. If None, bits are
        produced until the fraction is exhausted. At least one fractional
        bit is always produced.

    Returns the binary representation, e.g. "10.1" for 2.5.

    Raises ValueError if the number cannot be parsed, is negative or not
    finite, or if num_bits is None and the number has no finite binary
    expansion (for example 0.1).
    """
    try:
        value = Decimal(str(float_num))
    except InvalidOperation as err:
        raise ValueError(f"Cannot convert {float_num!r} to binary.") from err
    if not value.is_finite() or value < 0:
        raise ValueError("float2bin() only supports finite, non-negative numbers.")

    integers = int(value)
    # Fractions keep the repeated doubling exact, whatever the input length.
    fraction = Fraction(value) - integers

    denominator = fraction.denominator
    if num_bits is None and denominator & (denominator - 1):
        # Only fractions whose reduced denominator is a power of two
        # terminate in base 2; anything else would loop forever.
        raise ValueError(
            f"{float_num!r} has no finite binary expansion; pass num_bits.")

    bits = []
    while True:
        fraction *= 2
        bit = int(fraction)
        bits.append(str(bit))
        fraction -= bit

        if num_bits is None:
            if fraction == 0:
                break
        elif len(bits) >= num_bits:
            break

    return bin(integers)[2:] + "." + "".join(bits)


def bin2float(bin_num):
    """
    Converts a binary string to a Decimal.

    bin_num: The binary number as a string, e.g. "10.1". Every character
        other than "0" is treated as a set bit.

    Returns the value as a Decimal. Fractional bits are summed with Decimal
    arithmetic, so the result is subject to the active Decimal precision.
    """
    integers, _, decimals = bin_num.partition(".")
    result = Decimal(0.0)

    # Integer bits: the rightmost bit has weight 2**0.
    for power, bit in enumerate(reversed(integers)):
        if bit == "0":
            continue
        result = result + Decimal(2 ** power)

    # Fractional bits: the leftmost bit has weight 2**-1.
    for idx, bit in enumerate(decimals):
        if bit == "0":
            continue
        result = result + Decimal(1.0) / Decimal(2 ** (idx + 1))
    return result


if __name__ == "__main__":
    coder = CustomArithmeticEncoding({})

    low = 0.25
    high = 0.5

    fast_code = coder.custom_binary_encoding(low, high)

    print(fast_code)