feat: decompression
This commit is contained in:
parent
77b80914e8
commit
eec3b2b1e6
2 changed files with 52 additions and 10 deletions
|
|
@ -65,18 +65,38 @@ def compress(
|
|||
binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
|
||||
|
||||
# Pack
|
||||
bits = binary_code.split(".", maxsplit=1)[1]
|
||||
val = int(bits, 2) if len(bits) else 0
|
||||
out_bytes = val.to_bytes((len(bits) + 7) // 8, "big")
|
||||
val = int(binary_code, 2) if len(binary_code) else 0
|
||||
out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
|
||||
|
||||
if output_file:
|
||||
print(f"Writing to {output_file}")
|
||||
with open(output_file, "wb") as file:
|
||||
file.write(out_bytes)
|
||||
with open(output_file, "w") as file:
|
||||
file.write(f"{len(byte_data)}\n")
|
||||
file.write(binary_code) # todo: temporary, decoding depends on binary string
|
||||
else:
|
||||
print(out_bytes)
|
||||
|
||||
|
||||
def bits_to_number(bits: str) -> float:
|
||||
n = 0
|
||||
for i, bit in enumerate(bits, start=1):
|
||||
n += int(bit) / (1 << i)
|
||||
return n
|
||||
|
||||
|
||||
def make_cumulative(probs):
|
||||
cumulative = []
|
||||
|
||||
total = 0
|
||||
|
||||
for prob in probs:
|
||||
low = total
|
||||
high = total + prob
|
||||
cumulative.append((low, high))
|
||||
total = high
|
||||
return cumulative
|
||||
|
||||
|
||||
def decompress(
|
||||
device,
|
||||
model_path: str,
|
||||
|
|
@ -84,10 +104,10 @@ def decompress(
|
|||
output_file: str | None = None
|
||||
):
|
||||
context_length = 128
|
||||
output = bytearray()
|
||||
|
||||
print("Reading in the data")
|
||||
with open(input_file, "rb") as f:
|
||||
with open(input_file, "r") as f:
|
||||
length = int(f.readline())
|
||||
bytes_data = f.read()
|
||||
|
||||
if len(bytes_data) == 0:
|
||||
|
|
@ -100,8 +120,29 @@ def decompress(
|
|||
model.eval()
|
||||
|
||||
print("Decompressing")
|
||||
ae = ArithmeticEncoding(frequency_table={0: 1})
|
||||
stage_min, stage_max = Decimal(0), Decimal(1)
|
||||
|
||||
context = deque([0] * context_length, maxlen=context_length)
|
||||
output = bytearray()
|
||||
|
||||
x = bits_to_number(bytes_data)
|
||||
|
||||
for _ in range(length):
|
||||
probs = model(context)
|
||||
cumulative = make_cumulative(probs)
|
||||
|
||||
for symbol, (low, high) in enumerate(cumulative):
|
||||
if low <= x < high:
|
||||
break
|
||||
|
||||
output.append(symbol)
|
||||
context.append(chr(symbol))
|
||||
|
||||
interval_low, interval_high = cumulative[symbol]
|
||||
interval_width = interval_high - interval_low
|
||||
x = (x - interval_low) / interval_width
|
||||
|
||||
if output_file is not None:
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(output)
|
||||
return
|
||||
|
||||
print(output.decode('utf-8', errors='replace'))
|
||||
|
|
|
|||
|
|
@ -219,6 +219,7 @@ class CustomArithmeticEncoding:
|
|||
|
||||
return ''.join(map(str, code))
|
||||
|
||||
|
||||
def decode(self, encoded_msg, msg_length, probability_table):
|
||||
"""
|
||||
Decodes a message from a floating-point number.
|
||||
|
|
|
|||
Reference in a new issue