feat: decompression
This commit is contained in:
parent
77b80914e8
commit
eec3b2b1e6
2 changed files with 52 additions and 10 deletions
|
|
@ -65,18 +65,38 @@ def compress(
|
||||||
binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
|
binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
|
||||||
|
|
||||||
# Pack
|
# Pack
|
||||||
bits = binary_code.split(".", maxsplit=1)[1]
|
val = int(binary_code, 2) if len(binary_code) else 0
|
||||||
val = int(bits, 2) if len(bits) else 0
|
out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
|
||||||
out_bytes = val.to_bytes((len(bits) + 7) // 8, "big")
|
|
||||||
|
|
||||||
if output_file:
|
if output_file:
|
||||||
print(f"Writing to {output_file}")
|
print(f"Writing to {output_file}")
|
||||||
with open(output_file, "wb") as file:
|
with open(output_file, "w") as file:
|
||||||
file.write(out_bytes)
|
file.write(f"{len(byte_data)}\n")
|
||||||
|
file.write(binary_code) # todo: temporary, decoding depends on binary string
|
||||||
else:
|
else:
|
||||||
print(out_bytes)
|
print(out_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def bits_to_number(bits: str) -> float:
|
||||||
|
n = 0
|
||||||
|
for i, bit in enumerate(bits, start=1):
|
||||||
|
n += int(bit) / (1 << i)
|
||||||
|
return n
|
||||||
|
|
||||||
|
|
||||||
|
def make_cumulative(probs):
|
||||||
|
cumulative = []
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
for prob in probs:
|
||||||
|
low = total
|
||||||
|
high = total + prob
|
||||||
|
cumulative.append((low, high))
|
||||||
|
total = high
|
||||||
|
return cumulative
|
||||||
|
|
||||||
|
|
||||||
def decompress(
|
def decompress(
|
||||||
device,
|
device,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
|
|
@ -84,10 +104,10 @@ def decompress(
|
||||||
output_file: str | None = None
|
output_file: str | None = None
|
||||||
):
|
):
|
||||||
context_length = 128
|
context_length = 128
|
||||||
output = bytearray()
|
|
||||||
|
|
||||||
print("Reading in the data")
|
print("Reading in the data")
|
||||||
with open(input_file, "rb") as f:
|
with open(input_file, "r") as f:
|
||||||
|
length = int(f.readline())
|
||||||
bytes_data = f.read()
|
bytes_data = f.read()
|
||||||
|
|
||||||
if len(bytes_data) == 0:
|
if len(bytes_data) == 0:
|
||||||
|
|
@ -100,8 +120,29 @@ def decompress(
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
print("Decompressing")
|
print("Decompressing")
|
||||||
ae = ArithmeticEncoding(frequency_table={0: 1})
|
|
||||||
stage_min, stage_max = Decimal(0), Decimal(1)
|
|
||||||
|
|
||||||
context = deque([0] * context_length, maxlen=context_length)
|
context = deque([0] * context_length, maxlen=context_length)
|
||||||
|
output = bytearray()
|
||||||
|
|
||||||
|
x = bits_to_number(bytes_data)
|
||||||
|
|
||||||
|
for _ in range(length):
|
||||||
|
probs = model(context)
|
||||||
|
cumulative = make_cumulative(probs)
|
||||||
|
|
||||||
|
for symbol, (low, high) in enumerate(cumulative):
|
||||||
|
if low <= x < high:
|
||||||
|
break
|
||||||
|
|
||||||
|
output.append(symbol)
|
||||||
|
context.append(chr(symbol))
|
||||||
|
|
||||||
|
interval_low, interval_high = cumulative[symbol]
|
||||||
|
interval_width = interval_high - interval_low
|
||||||
|
x = (x - interval_low) / interval_width
|
||||||
|
|
||||||
|
if output_file is not None:
|
||||||
|
with open(output_file, "wb") as f:
|
||||||
|
f.write(output)
|
||||||
|
return
|
||||||
|
|
||||||
|
print(output.decode('utf-8', errors='replace'))
|
||||||
|
|
|
||||||
|
|
@ -219,6 +219,7 @@ class CustomArithmeticEncoding:
|
||||||
|
|
||||||
return ''.join(map(str, code))
|
return ''.join(map(str, code))
|
||||||
|
|
||||||
|
|
||||||
def decode(self, encoded_msg, msg_length, probability_table):
|
def decode(self, encoded_msg, msg_length, probability_table):
|
||||||
"""
|
"""
|
||||||
Decodes a message from a floating-point number.
|
Decodes a message from a floating-point number.
|
||||||
|
|
|
||||||
Reference in a new issue