fix: encoding binary now SUPAH fast
This commit is contained in:
parent
653b44804a
commit
77b80914e8
2 changed files with 44 additions and 35 deletions
|
|
@ -62,7 +62,7 @@ def compress(
|
||||||
print("Getting encoded value")
|
print("Getting encoded value")
|
||||||
interval_min, interval_max, _ = AE.get_encoded_value(stage)
|
interval_min, interval_max, _ = AE.get_encoded_value(stage)
|
||||||
print("Encoding in binary")
|
print("Encoding in binary")
|
||||||
binary_code, _ = AE.encode_binary(interval_min, interval_max)
|
binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
|
||||||
|
|
||||||
# Pack
|
# Pack
|
||||||
bits = binary_code.split(".", maxsplit=1)[1]
|
bits = binary_code.split(".", maxsplit=1)[1]
|
||||||
|
|
@ -77,5 +77,31 @@ def compress(
|
||||||
print(out_bytes)
|
print(out_bytes)
|
||||||
|
|
||||||
|
|
||||||
def decompress():
|
def decompress(
|
||||||
return NotImplementedError("Decompression is not implemented yet")
|
device,
|
||||||
|
model_path: str,
|
||||||
|
input_file: str,
|
||||||
|
output_file: str | None = None
|
||||||
|
):
|
||||||
|
context_length = 128
|
||||||
|
output = bytearray()
|
||||||
|
|
||||||
|
print("Reading in the data")
|
||||||
|
with open(input_file, "rb") as f:
|
||||||
|
bytes_data = f.read()
|
||||||
|
|
||||||
|
if len(bytes_data) == 0:
|
||||||
|
print("Input file is empty, nothing has to be done...")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("Loading the model")
|
||||||
|
model = torch.load(model_path, weights_only=False)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
print("Decompressing")
|
||||||
|
ae = ArithmeticEncoding(frequency_table={0: 1})
|
||||||
|
stage_min, stage_max = Decimal(0), Decimal(1)
|
||||||
|
|
||||||
|
context = deque([0] * context_length, maxlen=context_length)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -201,40 +201,23 @@ class CustomArithmeticEncoding:
|
||||||
float_interval_max: float
|
float_interval_max: float
|
||||||
"""
|
"""
|
||||||
code = []
|
code = []
|
||||||
k = 1
|
found = False
|
||||||
halves = [
|
next_n = 0.5
|
||||||
[0.0, 1 / 2],
|
n = 0
|
||||||
[1 / 2, 1.0]
|
|
||||||
]
|
|
||||||
|
|
||||||
i = 0
|
while not found:
|
||||||
|
if n + next_n < float_interval_max:
|
||||||
while i < 1024:
|
|
||||||
k += 1
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
if halves[0][0] >= float_interval_min and halves[0][1] < float_interval_max:
|
|
||||||
break
|
|
||||||
if halves[1][0] >= float_interval_min and halves[1][1] < float_interval_max:
|
|
||||||
break
|
|
||||||
|
|
||||||
# left interval, insert 0
|
|
||||||
if float_interval_max < halves[0][1]:
|
|
||||||
code.append(0)
|
|
||||||
low = halves[0][0]
|
|
||||||
high = halves[0][1]
|
|
||||||
|
|
||||||
else:
|
|
||||||
code.append(1)
|
code.append(1)
|
||||||
low = halves[1][0]
|
n += next_n
|
||||||
high = halves[1][1]
|
|
||||||
|
|
||||||
halves[0][0] = low
|
if n >= float_interval_min:
|
||||||
halves[0][1] = low + 1 / (1 << k)
|
found = True
|
||||||
halves[1][0] = halves[0][1]
|
else:
|
||||||
halves[1][1] = high
|
code.append(0)
|
||||||
|
|
||||||
return "0." + ''.join(map(str, code)), k
|
next_n /= 2
|
||||||
|
|
||||||
|
return ''.join(map(str, code))
|
||||||
|
|
||||||
def decode(self, encoded_msg, msg_length, probability_table):
|
def decode(self, encoded_msg, msg_length, probability_table):
|
||||||
"""
|
"""
|
||||||
|
|
@ -358,8 +341,8 @@ def bin2float(bin_num):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
coder = CustomArithmeticEncoding({})
|
coder = CustomArithmeticEncoding({})
|
||||||
|
|
||||||
low = 0.25
|
low = 0.00324
|
||||||
high = 0.5
|
high = 0.357
|
||||||
|
|
||||||
# slow_code = coder.encode_binary(low, high)
|
# slow_code = coder.encode_binary(low, high)
|
||||||
fast_code = coder.custom_binary_encoding(low, high)
|
fast_code = coder.custom_binary_encoding(low, high)
|
||||||
|
|
|
||||||
Reference in a new issue