fix: encoding binary now SUPAH fast
This commit is contained in:
parent
653b44804a
commit
77b80914e8
2 changed files with 44 additions and 35 deletions
|
|
@ -62,7 +62,7 @@ def compress(
|
|||
print("Getting encoded value")
|
||||
interval_min, interval_max, _ = AE.get_encoded_value(stage)
|
||||
print("Encoding in binary")
|
||||
binary_code, _ = AE.encode_binary(interval_min, interval_max)
|
||||
binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
|
||||
|
||||
# Pack
|
||||
bits = binary_code.split(".", maxsplit=1)[1]
|
||||
|
|
@ -77,5 +77,31 @@ def compress(
|
|||
print(out_bytes)
|
||||
|
||||
|
||||
def decompress():
|
||||
return NotImplementedError("Decompression is not implemented yet")
|
||||
def decompress(
|
||||
device,
|
||||
model_path: str,
|
||||
input_file: str,
|
||||
output_file: str | None = None
|
||||
):
|
||||
context_length = 128
|
||||
output = bytearray()
|
||||
|
||||
print("Reading in the data")
|
||||
with open(input_file, "rb") as f:
|
||||
bytes_data = f.read()
|
||||
|
||||
if len(bytes_data) == 0:
|
||||
print("Input file is empty, nothing has to be done...")
|
||||
return
|
||||
|
||||
print("Loading the model")
|
||||
model = torch.load(model_path, weights_only=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
print("Decompressing")
|
||||
ae = ArithmeticEncoding(frequency_table={0: 1})
|
||||
stage_min, stage_max = Decimal(0), Decimal(1)
|
||||
|
||||
context = deque([0] * context_length, maxlen=context_length)
|
||||
|
||||
|
|
|
|||
|
|
@ -201,40 +201,23 @@ class CustomArithmeticEncoding:
|
|||
float_interval_max: float
|
||||
"""
|
||||
code = []
|
||||
k = 1
|
||||
halves = [
|
||||
[0.0, 1 / 2],
|
||||
[1 / 2, 1.0]
|
||||
]
|
||||
found = False
|
||||
next_n = 0.5
|
||||
n = 0
|
||||
|
||||
i = 0
|
||||
|
||||
while i < 1024:
|
||||
k += 1
|
||||
i += 1
|
||||
|
||||
if halves[0][0] >= float_interval_min and halves[0][1] < float_interval_max:
|
||||
break
|
||||
if halves[1][0] >= float_interval_min and halves[1][1] < float_interval_max:
|
||||
break
|
||||
|
||||
# left interval, insert 0
|
||||
if float_interval_max < halves[0][1]:
|
||||
code.append(0)
|
||||
low = halves[0][0]
|
||||
high = halves[0][1]
|
||||
|
||||
else:
|
||||
while not found:
|
||||
if n + next_n < float_interval_max:
|
||||
code.append(1)
|
||||
low = halves[1][0]
|
||||
high = halves[1][1]
|
||||
n += next_n
|
||||
|
||||
halves[0][0] = low
|
||||
halves[0][1] = low + 1 / (1 << k)
|
||||
halves[1][0] = halves[0][1]
|
||||
halves[1][1] = high
|
||||
if n >= float_interval_min:
|
||||
found = True
|
||||
else:
|
||||
code.append(0)
|
||||
|
||||
return "0." + ''.join(map(str, code)), k
|
||||
next_n /= 2
|
||||
|
||||
return ''.join(map(str, code))
|
||||
|
||||
def decode(self, encoded_msg, msg_length, probability_table):
|
||||
"""
|
||||
|
|
@ -358,8 +341,8 @@ def bin2float(bin_num):
|
|||
if __name__ == "__main__":
|
||||
coder = CustomArithmeticEncoding({})
|
||||
|
||||
low = 0.25
|
||||
high = 0.5
|
||||
low = 0.00324
|
||||
high = 0.357
|
||||
|
||||
# slow_code = coder.encode_binary(low, high)
|
||||
fast_code = coder.custom_binary_encoding(low, high)
|
||||
|
|
|
|||
Reference in a new issue