Merge branch 'main' of github.ugent.be:ML/neural-compression
1
.gitignore
vendored
|
|
@ -3,3 +3,4 @@ __pycache__
|
||||||
data/
|
data/
|
||||||
saved_models/
|
saved_models/
|
||||||
output/
|
output/
|
||||||
|
.DS_Store
|
||||||
|
|
|
||||||
|
Before Width: | Height: | Size: 26 KiB |
|
Before Width: | Height: | Size: 30 KiB After Width: | Height: | Size: 28 KiB |
|
Before Width: | Height: | Size: 19 KiB |
|
Before Width: | Height: | Size: 19 KiB |
BIN
graphs/autoencoder_enwik9_execution_time.png
Normal file
|
After Width: | Height: | Size: 32 KiB |
|
Before Width: | Height: | Size: 29 KiB |
|
Before Width: | Height: | Size: 31 KiB After Width: | Height: | Size: 29 KiB |
|
Before Width: | Height: | Size: 21 KiB |
|
Before Width: | Height: | Size: 23 KiB |
BIN
graphs/autoencoder_genome_execution_time.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
graphs/autoencoder_loss.png
Normal file
|
After Width: | Height: | Size: 31 KiB |
|
Before Width: | Height: | Size: 26 KiB |
|
Before Width: | Height: | Size: 35 KiB After Width: | Height: | Size: 32 KiB |
|
Before Width: | Height: | Size: 17 KiB |
|
Before Width: | Height: | Size: 17 KiB |
BIN
graphs/cnn_enwik9_execution_time.png
Normal file
|
After Width: | Height: | Size: 41 KiB |
BIN
graphs/cnn_enwik9_extrapolated_execution_time.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
|
Before Width: | Height: | Size: 30 KiB |
|
Before Width: | Height: | Size: 40 KiB After Width: | Height: | Size: 38 KiB |
|
Before Width: | Height: | Size: 18 KiB |
|
Before Width: | Height: | Size: 18 KiB |
BIN
graphs/cnn_genome_execution_time.png
Normal file
|
After Width: | Height: | Size: 45 KiB |
BIN
graphs/cnn_genome_extrapolated_execution_time.png
Normal file
|
After Width: | Height: | Size: 32 KiB |
BIN
graphs/cnn_loss.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
125
make_graphs.py
|
|
@ -1,5 +1,6 @@
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
@ -9,36 +10,99 @@ if __name__ == "__main__":
|
||||||
for dataset_type in df["dataset_type"].unique():
|
for dataset_type in df["dataset_type"].unique():
|
||||||
for model_type in df["model_type"].unique():
|
for model_type in df["model_type"].unique():
|
||||||
dataset_df = df[df["dataset_type"] == dataset_type]
|
dataset_df = df[df["dataset_type"] == dataset_type]
|
||||||
model_df = dataset_df[dataset_df["model_type"] == model_type]
|
model_df = dataset_df[dataset_df["model_type"] == model_type].copy()
|
||||||
|
|
||||||
# execution time
|
# execution time
|
||||||
plt.figure()
|
plt.figure()
|
||||||
grouped = model_df.groupby("context_length")["compression_time"].mean() / 1e9
|
model_df["original_file_size_mb"] = model_df["original_file_size"] / 1e6
|
||||||
labels = grouped.index.astype(str) # "128", "256"
|
model_df["compression_time_s"] = model_df["compression_time"] / 1e9
|
||||||
x = np.arange(len(labels)) # [0, 1]
|
model_df["decompression_time_s"] = model_df["decompression_time"] / 1e9
|
||||||
|
# compression
|
||||||
plt.bar(x, grouped.values, width=0.6)
|
sns.lineplot(
|
||||||
plt.title(f"{model_type.capitalize()} mean compression time")
|
data=model_df,
|
||||||
plt.xticks(x, labels)
|
x="original_file_size_mb",
|
||||||
plt.xlabel("Context length")
|
y="compression_time_s",
|
||||||
plt.ylabel("Mean compression time [s]")
|
hue="context_length",
|
||||||
|
palette="Set1",
|
||||||
|
markers=True,
|
||||||
|
legend="brief",
|
||||||
|
linestyle="-"
|
||||||
|
)
|
||||||
|
# decompression
|
||||||
|
sns.lineplot(
|
||||||
|
data=model_df,
|
||||||
|
x="original_file_size_mb",
|
||||||
|
y="decompression_time_s",
|
||||||
|
hue="context_length",
|
||||||
|
palette="Set1",
|
||||||
|
markers=True,
|
||||||
|
legend=False,
|
||||||
|
linestyle="--"
|
||||||
|
)
|
||||||
|
plt.title(f"{model_type.capitalize()} compression and decompression time: {dataset_type}")
|
||||||
|
plt.xlabel("file size [MB]")
|
||||||
|
plt.ylabel("Time [s]")
|
||||||
|
plt.yscale("log")
|
||||||
|
plt.legend([f"{style}, {c_type}" for style, c_type in zip(["Solid", "Dashed"], ["compression", "decompression"])])
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(f"./graphs/{model_type}_{dataset_type}_compression_time.png")
|
plt.savefig(f"./graphs/{model_type}_{dataset_type}_execution_time.png")
|
||||||
|
|
||||||
|
# compression ratio
|
||||||
plt.figure()
|
plt.figure()
|
||||||
grouped = model_df.groupby("context_length")["decompression_time"].mean() / 1e9
|
c256 = model_df[model_df["context_length"] == 256]
|
||||||
labels = grouped.index.astype(str) # "128", "256"
|
c128 = model_df[model_df["context_length"] == 128]
|
||||||
x = np.arange(len(labels)) # [0, 1]
|
|
||||||
|
|
||||||
plt.bar(x, grouped.values, width=0.6)
|
plt.plot(c256["original_file_size"] / 1e6, c256["compressed_file_size"] / 1e6, label="256")
|
||||||
plt.title(f"{model_type.capitalize()} mean decompression time")
|
plt.plot(c128["original_file_size"] / 1e6, c128["compressed_file_size"] / 1e6, label="128")
|
||||||
plt.xticks(x, labels)
|
plt.title(f"{model_type.capitalize()} compressed file evolution: {dataset_type}")
|
||||||
plt.xlabel("Context length")
|
plt.xlabel("Original file size [MB]")
|
||||||
plt.ylabel("Mean decompression time [s]")
|
plt.ylabel("Compressed file size [MB]")
|
||||||
plt.tight_layout()
|
plt.legend()
|
||||||
plt.savefig(f"./graphs/{model_type}_{dataset_type}_decompression_time.png")
|
plt.savefig(f"./graphs/{model_type}_{dataset_type}_compression_ratio.png")
|
||||||
|
|
||||||
|
|
||||||
|
# if model_type == "cnn":
|
||||||
|
# import numpy as np
|
||||||
|
#
|
||||||
|
# plt.figure()
|
||||||
|
# for length, linestyle in [(128, '-'), (256, '--')]:
|
||||||
|
# # extrapolate execution time to larger files
|
||||||
|
# x = model_df[model_df["context_length"] == length]["original_file_size"] / 1e6
|
||||||
|
# y = model_df[model_df["context_length"] == length]["compression_time"]
|
||||||
|
# y_decom = model_df[model_df["context_length"] == length]["decompression_time"]
|
||||||
|
#
|
||||||
|
# b1, loga1 = np.polyfit(x, np.log(y), 1)
|
||||||
|
# b2, loga2 = np.polyfit(x, np.log(y_decom), 1)
|
||||||
|
#
|
||||||
|
# x_comp = np.linspace(0, 40, 1000)
|
||||||
|
# x_decomp = np.linspace(0, 40, 1000)
|
||||||
|
# a1 = np.exp(loga1)
|
||||||
|
# a2 = np.exp(loga2)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# plt.plot(
|
||||||
|
# x_comp, a1 * np.exp(x_comp),
|
||||||
|
# label=f"{length} compression",
|
||||||
|
# linestyle=linestyle
|
||||||
|
# )
|
||||||
|
# plt.plot(
|
||||||
|
# x_decomp, a2 * np.exp(x_decomp),
|
||||||
|
# label=f"{length} decompression",
|
||||||
|
# linestyle=linestyle
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# plt.legend()
|
||||||
|
# plt.title(f"Extrapolated execution time for CNN compression and decompression")
|
||||||
|
# plt.xlabel("File size [MB]")
|
||||||
|
# plt.ylabel("Time [s]")
|
||||||
|
# plt.tight_layout()
|
||||||
|
# plt.savefig(f"./graphs/{model_type}_{dataset_type}_extrapolated_execution_time.png")
|
||||||
|
|
||||||
|
for model_type in df["model_type"].unique():
|
||||||
|
model_df = df[df["model_type"] == model_type]
|
||||||
|
|
||||||
# loss
|
|
||||||
plt.figure(figsize=(10, 4))
|
plt.figure(figsize=(10, 4))
|
||||||
bar_height = 0.25
|
bar_height = 0.25
|
||||||
files = model_df["input_file_name"].unique()
|
files = model_df["input_file_name"].unique()
|
||||||
|
|
@ -60,22 +124,9 @@ if __name__ == "__main__":
|
||||||
label="128"
|
label="128"
|
||||||
)
|
)
|
||||||
plt.yticks(y, files, rotation=45, ha="right")
|
plt.yticks(y, files, rotation=45, ha="right")
|
||||||
plt.title(f"{model_type.capitalize()} MSE loss for different context lengths")
|
plt.title(f"MSE loss for different context lengths")
|
||||||
plt.xlabel("MSE loss")
|
plt.xlabel("MSE loss")
|
||||||
plt.ylabel("Filename")
|
plt.ylabel("Filename")
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(f"./graphs/{model_type}_{dataset_type}_accuracy.png")
|
plt.savefig(f"./graphs/{model_type}_loss.png")
|
||||||
|
|
||||||
# compression ratio
|
|
||||||
plt.figure()
|
|
||||||
c256 = model_df[model_df["context_length"] == 256]
|
|
||||||
c128 = model_df[model_df["context_length"] == 128]
|
|
||||||
|
|
||||||
plt.plot(c256["original_file_size"] / 1e6, c256["compressed_file_size"] / 1e6, label="256")
|
|
||||||
plt.plot(c128["original_file_size"] / 1e6, c128["compressed_file_size"] / 1e6, label="128")
|
|
||||||
plt.title(f"{model_type.capitalize()} compressed file evolution")
|
|
||||||
plt.xlabel("Original file size [MB]")
|
|
||||||
plt.ylabel("Compressed file size [MB]")
|
|
||||||
plt.legend()
|
|
||||||
plt.savefig(f"./graphs/{model_type}_{dataset_type}_compression_ratio.png")
|
|
||||||
|
|
|
||||||
19
models/README.md
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Meta information about training
|
||||||
|
|
||||||
|
The trained models that are saved here follow the following naming convention:
|
||||||
|
|
||||||
|
* Optuna intermediate model: `[model]-[dataset]-[context].pt`
|
||||||
|
* Fully trained model: `[model]-[dataset]-full-[context].pt`
|
||||||
|
|
||||||
|
The following parameters were used:
|
||||||
|
|
||||||
|
* training size: 2048 for optuna, 209715 for full training
|
||||||
|
* context sizes: {128, 256}
|
||||||
|
|
||||||
|
The models were trained with the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python ./results/[cnn,autoencoder] train --method [full,optuna] \
|
||||||
|
--data-root ./data --dataset [genome,enwik9] --context [128,256] --size [2048,209715] \
|
||||||
|
--model-save-path ./models/<name> --model-load-path <path if full trainig, output from optuna>
|
||||||
|
```
|
||||||
|
|
@ -9,6 +9,8 @@ dependencies = [
|
||||||
"fsspec==2024.9.0",
|
"fsspec==2024.9.0",
|
||||||
"lorem>=0.1.1",
|
"lorem>=0.1.1",
|
||||||
"arithmeticencodingpython",
|
"arithmeticencodingpython",
|
||||||
|
"pandas-stubs~=2.3.3",
|
||||||
|
"seaborn>=0.13.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
|
||||||
|
|
@ -1,352 +0,0 @@
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
|
|
||||||
class CustomArithmeticEncoding:
|
|
||||||
"""
|
|
||||||
ArithmeticEncoding is a class for building the arithmetic encoding.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, frequency_table, save_stages=False):
|
|
||||||
"""
|
|
||||||
frequency_table: Frequency table as a dictionary where key is the symbol and value is the frequency.
|
|
||||||
save_stages: If True, then the intervals of each stage are saved in a list. Note that setting save_stages=True may cause memory overflow if the message is large
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.save_stages = save_stages
|
|
||||||
if (save_stages == True):
|
|
||||||
print("WARNING: Setting save_stages=True may cause memory overflow if the message is large.")
|
|
||||||
|
|
||||||
self.probability_table = self.get_probability_table(frequency_table)
|
|
||||||
|
|
||||||
def get_probability_table(self, frequency_table):
|
|
||||||
"""
|
|
||||||
Calculates the probability table out of the frequency table.
|
|
||||||
|
|
||||||
frequency_table: A table of the term frequencies.
|
|
||||||
|
|
||||||
Returns the probability table.
|
|
||||||
"""
|
|
||||||
total_frequency = sum(list(frequency_table.values()))
|
|
||||||
|
|
||||||
probability_table = {}
|
|
||||||
for key, value in frequency_table.items():
|
|
||||||
probability_table[key] = value / total_frequency
|
|
||||||
|
|
||||||
return probability_table
|
|
||||||
|
|
||||||
def get_encoded_value(self, last_stage_probs):
|
|
||||||
"""
|
|
||||||
After encoding the entire message, this method returns the single value that represents the entire message.
|
|
||||||
|
|
||||||
last_stage_probs: A list of the probabilities in the last stage.
|
|
||||||
|
|
||||||
Returns the minimum and maximum probabilites in the last stage in addition to the value encoding the message.
|
|
||||||
"""
|
|
||||||
last_stage_probs = list(last_stage_probs.values())
|
|
||||||
last_stage_values = []
|
|
||||||
for sublist in last_stage_probs:
|
|
||||||
for element in sublist:
|
|
||||||
last_stage_values.append(element)
|
|
||||||
|
|
||||||
last_stage_min = min(last_stage_values)
|
|
||||||
last_stage_max = max(last_stage_values)
|
|
||||||
encoded_value = (last_stage_min + last_stage_max) / 2
|
|
||||||
|
|
||||||
return last_stage_min, last_stage_max, encoded_value
|
|
||||||
|
|
||||||
def process_stage(self, probability_table, stage_min, stage_max):
|
|
||||||
"""
|
|
||||||
Processing a stage in the encoding/decoding process.
|
|
||||||
|
|
||||||
probability_table: The probability table.
|
|
||||||
stage_min: The minumim probability of the current stage.
|
|
||||||
stage_max: The maximum probability of the current stage.
|
|
||||||
|
|
||||||
Returns the probabilities in the stage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
stage_probs = {}
|
|
||||||
stage_domain = stage_max - stage_min
|
|
||||||
for term_idx in range(len(probability_table.items())):
|
|
||||||
term = list(probability_table.keys())[term_idx]
|
|
||||||
term_prob = Decimal(probability_table[term])
|
|
||||||
cum_prob = term_prob * stage_domain + stage_min
|
|
||||||
stage_probs[term] = [stage_min, cum_prob]
|
|
||||||
stage_min = cum_prob
|
|
||||||
return stage_probs
|
|
||||||
|
|
||||||
def encode(self, msg, probability_table):
|
|
||||||
"""
|
|
||||||
Encodes a message using arithmetic encoding.
|
|
||||||
|
|
||||||
msg: The message to be encoded.
|
|
||||||
probability_table: The probability table.
|
|
||||||
|
|
||||||
Returns the encoder, the floating-point value representing the encoded message, and the maximum and minimum values of the interval in which the floating-point value falls.
|
|
||||||
"""
|
|
||||||
|
|
||||||
msg = list(msg)
|
|
||||||
|
|
||||||
encoder = []
|
|
||||||
|
|
||||||
stage_min = Decimal(0.0)
|
|
||||||
stage_max = Decimal(1.0)
|
|
||||||
|
|
||||||
for msg_term_idx in range(len(msg)):
|
|
||||||
stage_probs = self.process_stage(probability_table, stage_min, stage_max)
|
|
||||||
|
|
||||||
msg_term = msg[msg_term_idx]
|
|
||||||
stage_min = stage_probs[msg_term][0]
|
|
||||||
stage_max = stage_probs[msg_term][1]
|
|
||||||
|
|
||||||
if self.save_stages:
|
|
||||||
encoder.append(stage_probs)
|
|
||||||
|
|
||||||
last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)
|
|
||||||
|
|
||||||
if self.save_stages:
|
|
||||||
encoder.append(last_stage_probs)
|
|
||||||
|
|
||||||
interval_min_value, interval_max_value, encoded_msg = self.get_encoded_value(last_stage_probs)
|
|
||||||
|
|
||||||
return encoded_msg, encoder, interval_min_value, interval_max_value
|
|
||||||
|
|
||||||
def process_stage_binary(self, float_interval_min, float_interval_max, stage_min_bin, stage_max_bin):
|
|
||||||
"""
|
|
||||||
Processing a stage in the encoding/decoding process.
|
|
||||||
|
|
||||||
float_interval_min: The minimum floating-point value in the interval in which the floating-point value that encodes the message is located.
|
|
||||||
float_interval_max: The maximum floating-point value in the interval in which the floating-point value that encodes the message is located.
|
|
||||||
stage_min_bin: The minimum binary number in the current stage.
|
|
||||||
stage_max_bin: The maximum binary number in the current stage.
|
|
||||||
|
|
||||||
Returns the probabilities of the terms in this stage. There are only 2 terms.
|
|
||||||
"""
|
|
||||||
|
|
||||||
stage_mid_bin = stage_min_bin + "1"
|
|
||||||
stage_min_bin = stage_min_bin + "0"
|
|
||||||
|
|
||||||
stage_probs = {}
|
|
||||||
stage_probs[0] = [stage_min_bin, stage_mid_bin]
|
|
||||||
stage_probs[1] = [stage_mid_bin, stage_max_bin]
|
|
||||||
|
|
||||||
return stage_probs
|
|
||||||
|
|
||||||
def encode_binary(self, float_interval_min, float_interval_max):
|
|
||||||
"""
|
|
||||||
Calculates the binary code that represents the floating-point value that encodes the message.
|
|
||||||
|
|
||||||
float_interval_min: The minimum floating-point value in the interval in which the floating-point value that encodes the message is located.
|
|
||||||
float_interval_max: The maximum floating-point value in the interval in which the floating-point value that encodes the message is located.
|
|
||||||
|
|
||||||
Returns the binary code representing the encoded message.
|
|
||||||
"""
|
|
||||||
|
|
||||||
binary_encoder = []
|
|
||||||
binary_code = None
|
|
||||||
|
|
||||||
stage_min_bin = "0.0"
|
|
||||||
stage_max_bin = "1.0"
|
|
||||||
|
|
||||||
stage_probs = {}
|
|
||||||
stage_probs[0] = [stage_min_bin, "0.1"]
|
|
||||||
stage_probs[1] = ["0.1", stage_max_bin]
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if float_interval_max < bin2float(stage_probs[0][1]):
|
|
||||||
stage_min_bin = stage_probs[0][0]
|
|
||||||
stage_max_bin = stage_probs[0][1]
|
|
||||||
else:
|
|
||||||
stage_min_bin = stage_probs[1][0]
|
|
||||||
stage_max_bin = stage_probs[1][1]
|
|
||||||
|
|
||||||
if self.save_stages:
|
|
||||||
binary_encoder.append(stage_probs)
|
|
||||||
|
|
||||||
stage_probs = self.process_stage_binary(float_interval_min,
|
|
||||||
float_interval_max,
|
|
||||||
stage_min_bin,
|
|
||||||
stage_max_bin)
|
|
||||||
|
|
||||||
# print(stage_probs[0][0], bin2float(stage_probs[0][0]))
|
|
||||||
# print(stage_probs[0][1], bin2float(stage_probs[0][1]))
|
|
||||||
if (bin2float(stage_probs[0][0]) >= float_interval_min) and (
|
|
||||||
bin2float(stage_probs[0][1]) < float_interval_max):
|
|
||||||
# The binary code is found.
|
|
||||||
# print(stage_probs[0][0], bin2float(stage_probs[0][0]))
|
|
||||||
# print(stage_probs[0][1], bin2float(stage_probs[0][1]))
|
|
||||||
# print("The binary code is : ", stage_probs[0][0])
|
|
||||||
binary_code = stage_probs[0][0]
|
|
||||||
break
|
|
||||||
elif (bin2float(stage_probs[1][0]) >= float_interval_min) and (
|
|
||||||
bin2float(stage_probs[1][1]) < float_interval_max):
|
|
||||||
# The binary code is found.
|
|
||||||
# print(stage_probs[1][0], bin2float(stage_probs[1][0]))
|
|
||||||
# print(stage_probs[1][1], bin2float(stage_probs[1][1]))
|
|
||||||
# print("The binary code is : ", stage_probs[1][0])
|
|
||||||
binary_code = stage_probs[1][0]
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.save_stages:
|
|
||||||
binary_encoder.append(stage_probs)
|
|
||||||
|
|
||||||
return binary_code, binary_encoder
|
|
||||||
|
|
||||||
def custom_binary_encoding(self, float_interval_min, float_interval_max):
|
|
||||||
"""
|
|
||||||
Find the binary representation of the floating punt number which lies in
|
|
||||||
[float_interval_min, float_interval_max).
|
|
||||||
|
|
||||||
float_interval_min: float
|
|
||||||
float_interval_max: float
|
|
||||||
"""
|
|
||||||
code = []
|
|
||||||
found = False
|
|
||||||
next_n = 0.5
|
|
||||||
n = 0
|
|
||||||
|
|
||||||
while not found:
|
|
||||||
if n + next_n < float_interval_max:
|
|
||||||
code.append(1)
|
|
||||||
n += next_n
|
|
||||||
|
|
||||||
if n >= float_interval_min:
|
|
||||||
found = True
|
|
||||||
else:
|
|
||||||
code.append(0)
|
|
||||||
|
|
||||||
next_n /= 2
|
|
||||||
|
|
||||||
return ''.join(map(str, code))
|
|
||||||
|
|
||||||
|
|
||||||
def decode(self, encoded_msg, msg_length, probability_table):
|
|
||||||
"""
|
|
||||||
Decodes a message from a floating-point number.
|
|
||||||
|
|
||||||
encoded_msg: The floating-point value that encodes the message.
|
|
||||||
msg_length: Length of the message.
|
|
||||||
probability_table: The probability table.
|
|
||||||
|
|
||||||
Returns the decoded message.
|
|
||||||
"""
|
|
||||||
|
|
||||||
decoder = []
|
|
||||||
|
|
||||||
decoded_msg = []
|
|
||||||
|
|
||||||
stage_min = Decimal(0.0)
|
|
||||||
stage_max = Decimal(1.0)
|
|
||||||
|
|
||||||
for idx in range(msg_length):
|
|
||||||
stage_probs = self.process_stage(probability_table, stage_min, stage_max)
|
|
||||||
|
|
||||||
for msg_term, value in stage_probs.items():
|
|
||||||
if encoded_msg >= value[0] and encoded_msg <= value[1]:
|
|
||||||
break
|
|
||||||
|
|
||||||
decoded_msg.append(msg_term)
|
|
||||||
|
|
||||||
stage_min = stage_probs[msg_term][0]
|
|
||||||
stage_max = stage_probs[msg_term][1]
|
|
||||||
|
|
||||||
if self.save_stages:
|
|
||||||
decoder.append(stage_probs)
|
|
||||||
|
|
||||||
if self.save_stages:
|
|
||||||
last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)
|
|
||||||
decoder.append(last_stage_probs)
|
|
||||||
|
|
||||||
return decoded_msg, decoder
|
|
||||||
|
|
||||||
|
|
||||||
def float2bin(float_num, num_bits=None):
|
|
||||||
"""
|
|
||||||
Converts a floating-point number into binary.
|
|
||||||
|
|
||||||
float_num: The floating-point number.
|
|
||||||
num_bits: The number of bits expected in the result. If None, then the number of bits depends on the number.
|
|
||||||
|
|
||||||
Returns the binary representation of the number.
|
|
||||||
"""
|
|
||||||
|
|
||||||
float_num = str(float_num)
|
|
||||||
if float_num.find(".") == -1:
|
|
||||||
# No decimals in the floating-point number.
|
|
||||||
integers = float_num
|
|
||||||
decimals = ""
|
|
||||||
else:
|
|
||||||
integers, decimals = float_num.split(".")
|
|
||||||
decimals = "0." + decimals
|
|
||||||
decimals = Decimal(decimals)
|
|
||||||
integers = int(integers)
|
|
||||||
|
|
||||||
result = ""
|
|
||||||
num_used_bits = 0
|
|
||||||
while True:
|
|
||||||
mul = decimals * 2
|
|
||||||
int_part = int(mul)
|
|
||||||
result = result + str(int_part)
|
|
||||||
num_used_bits = num_used_bits + 1
|
|
||||||
|
|
||||||
decimals = mul - int(mul)
|
|
||||||
if type(num_bits) is type(None):
|
|
||||||
if decimals == 0:
|
|
||||||
break
|
|
||||||
elif num_used_bits >= num_bits:
|
|
||||||
break
|
|
||||||
if type(num_bits) is type(None):
|
|
||||||
pass
|
|
||||||
elif len(result) < num_bits:
|
|
||||||
num_remaining_bits = num_bits - len(result)
|
|
||||||
result = result + "0" * num_remaining_bits
|
|
||||||
|
|
||||||
integers_bin = bin(integers)[2:]
|
|
||||||
result = str(integers_bin) + "." + str(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def bin2float(bin_num):
|
|
||||||
"""
|
|
||||||
Converts a binary number to a floating-point number.
|
|
||||||
|
|
||||||
bin_num: The binary number as a string.
|
|
||||||
|
|
||||||
Returns the floating-point representation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if bin_num.find(".") == -1:
|
|
||||||
# No decimals in the binary number.
|
|
||||||
integers = bin_num
|
|
||||||
decimals = ""
|
|
||||||
else:
|
|
||||||
integers, decimals = bin_num.split(".")
|
|
||||||
result = Decimal(0.0)
|
|
||||||
|
|
||||||
# Working with integers.
|
|
||||||
for idx, bit in enumerate(integers):
|
|
||||||
if bit == "0":
|
|
||||||
continue
|
|
||||||
mul = 2 ** idx
|
|
||||||
result = result + Decimal(mul)
|
|
||||||
|
|
||||||
# Working with decimals.
|
|
||||||
for idx, bit in enumerate(decimals):
|
|
||||||
if bit == "0":
|
|
||||||
continue
|
|
||||||
mul = Decimal(1.0) / Decimal((2 ** (idx + 1)))
|
|
||||||
result = result + mul
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
coder = CustomArithmeticEncoding({})
|
|
||||||
|
|
||||||
low = 0.00324
|
|
||||||
high = 0.357
|
|
||||||
|
|
||||||
# slow_code = coder.encode_binary(low, high)
|
|
||||||
fast_code = coder.custom_binary_encoding(low, high)
|
|
||||||
|
|
||||||
# print(slow_code)
|
|
||||||
print(fast_code)
|
|
||||||
40
uv.lock
generated
|
|
@ -1613,6 +1613,19 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/70/44/5191d2e4026f86a2a109053e194d3ba7a31a2d10a9c2348368c63ed4e85a/pandas-2.3.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3869faf4bd07b3b66a9f462417d0ca3a9df29a9f6abd5d0d0dbab15dac7abe87", size = 13202175, upload-time = "2025-09-29T23:31:59.173Z" },
|
{ url = "https://files.pythonhosted.org/packages/70/44/5191d2e4026f86a2a109053e194d3ba7a31a2d10a9c2348368c63ed4e85a/pandas-2.3.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3869faf4bd07b3b66a9f462417d0ca3a9df29a9f6abd5d0d0dbab15dac7abe87", size = 13202175, upload-time = "2025-09-29T23:31:59.173Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pandas-stubs"
|
||||||
|
version = "2.3.3.251201"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "types-pytz" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/ee/a6/491b2af2cb3ee232765a73fb273a44cc1ac33b154f7745b2df2ee1dc4d01/pandas_stubs-2.3.3.251201.tar.gz", hash = "sha256:7a980f4f08cff2a6d7e4c6d6d26f4c5fcdb82a6f6531489b2f75c81567fe4536", size = 107787, upload-time = "2025-12-01T18:29:22.403Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e2/68/78a3c253f146254b8e2c19f4a4768f272e12ef11001d9b45ec7b165db054/pandas_stubs-2.3.3.251201-py3-none-any.whl", hash = "sha256:eb5c9b6138bd8492fd74a47b09c9497341a278fcfbc8633ea4b35b230ebf4be5", size = 164638, upload-time = "2025-12-01T18:29:21.006Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pillow"
|
name = "pillow"
|
||||||
version = "12.0.0"
|
version = "12.0.0"
|
||||||
|
|
@ -1718,6 +1731,8 @@ dependencies = [
|
||||||
{ name = "datasets" },
|
{ name = "datasets" },
|
||||||
{ name = "fsspec" },
|
{ name = "fsspec" },
|
||||||
{ name = "lorem" },
|
{ name = "lorem" },
|
||||||
|
{ name = "pandas-stubs" },
|
||||||
|
{ name = "seaborn" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.optional-dependencies]
|
[package.optional-dependencies]
|
||||||
|
|
@ -1746,7 +1761,9 @@ requires-dist = [
|
||||||
{ name = "matplotlib", marker = "extra == 'dev'", specifier = ">=3.10.7" },
|
{ name = "matplotlib", marker = "extra == 'dev'", specifier = ">=3.10.7" },
|
||||||
{ name = "memray", marker = "extra == 'dev'", specifier = ">=1.19.1" },
|
{ name = "memray", marker = "extra == 'dev'", specifier = ">=1.19.1" },
|
||||||
{ name = "optuna", marker = "extra == 'dev'", specifier = "==4.5.0" },
|
{ name = "optuna", marker = "extra == 'dev'", specifier = "==4.5.0" },
|
||||||
|
{ name = "pandas-stubs", specifier = "~=2.3.3" },
|
||||||
{ name = "regex", marker = "extra == 'dataset'", specifier = ">=2025.11.3" },
|
{ name = "regex", marker = "extra == 'dataset'", specifier = ">=2025.11.3" },
|
||||||
|
{ name = "seaborn", specifier = ">=0.13.2" },
|
||||||
{ name = "torch", marker = "extra == 'dev'", specifier = "==2.9.0" },
|
{ name = "torch", marker = "extra == 'dev'", specifier = "==2.9.0" },
|
||||||
{ name = "torchdata", marker = "extra == 'dev'", specifier = "==0.7.1" },
|
{ name = "torchdata", marker = "extra == 'dev'", specifier = "==0.7.1" },
|
||||||
{ name = "torchvision", marker = "extra == 'dev'", specifier = "==0.24.0" },
|
{ name = "torchvision", marker = "extra == 'dev'", specifier = "==0.24.0" },
|
||||||
|
|
@ -2116,6 +2133,20 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" },
|
{ url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "seaborn"
|
||||||
|
version = "0.13.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "matplotlib" },
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "pandas" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/86/59/a451d7420a77ab0b98f7affa3a1d78a313d2f7281a57afb1a34bae8ab412/seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7", size = 1457696, upload-time = "2024-01-25T13:21:52.551Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "setuptools"
|
name = "setuptools"
|
||||||
version = "80.9.0"
|
version = "80.9.0"
|
||||||
|
|
@ -2361,6 +2392,15 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/5e/dd/5cbf31f402f1cc0ab087c94d4669cfa55bd1e818688b910631e131d74e75/typer_slim-0.20.0-py3-none-any.whl", hash = "sha256:f42a9b7571a12b97dddf364745d29f12221865acef7a2680065f9bb29c7dc89d", size = 47087, upload-time = "2025-10-20T17:03:44.546Z" },
|
{ url = "https://files.pythonhosted.org/packages/5e/dd/5cbf31f402f1cc0ab087c94d4669cfa55bd1e818688b910631e131d74e75/typer_slim-0.20.0-py3-none-any.whl", hash = "sha256:f42a9b7571a12b97dddf364745d29f12221865acef7a2680065f9bb29c7dc89d", size = 47087, upload-time = "2025-10-20T17:03:44.546Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "types-pytz"
|
||||||
|
version = "2025.2.0.20251108"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/40/ff/c047ddc68c803b46470a357454ef76f4acd8c1088f5cc4891cdd909bfcf6/types_pytz-2025.2.0.20251108.tar.gz", hash = "sha256:fca87917836ae843f07129567b74c1929f1870610681b4c92cb86a3df5817bdb", size = 10961, upload-time = "2025-11-08T02:55:57.001Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl", hash = "sha256:0f1c9792cab4eb0e46c52f8845c8f77cf1e313cb3d68bf826aa867fe4717d91c", size = 10116, upload-time = "2025-11-08T02:55:56.194Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typing-extensions"
|
name = "typing-extensions"
|
||||||
version = "4.15.0"
|
version = "4.15.0"
|
||||||
|
|
|
||||||