WIP: Attempt at switching
This commit is contained in:
parent
817c16bde4
commit
2c0e0c2278
1 changed files with 35 additions and 31 deletions
|
|
@ -36,12 +36,16 @@ def compress(
|
||||||
|
|
||||||
# Init AE
|
# Init AE
|
||||||
print("Initializing AE")
|
print("Initializing AE")
|
||||||
AE = ArithmeticEncoding(frequency_table={0: 1}) # These are dummies because they are not used
|
|
||||||
|
with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
|
||||||
|
enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
|
||||||
|
|
||||||
|
context = deque([0] * context_length, maxlen=context_length)
|
||||||
|
|
||||||
stage_min, stage_max = Decimal(0), Decimal(1)
|
stage_min, stage_max = Decimal(0), Decimal(1)
|
||||||
stage = None
|
stage = None
|
||||||
|
|
||||||
# Compress
|
# Compress
|
||||||
context = deque([0] * context_length, maxlen=context_length)
|
|
||||||
for byte in tqdm(tensor.tolist(), desc="Compressing"):
|
for byte in tqdm(tensor.tolist(), desc="Compressing"):
|
||||||
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
|
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
|
@ -52,29 +56,29 @@ def compress(
|
||||||
|
|
||||||
eps = 1e-10
|
eps = 1e-10
|
||||||
frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
|
frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
|
||||||
|
probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities))
|
||||||
probability_table = AE.get_probability_table(frequency_table)
|
probability_table = AE.get_probability_table(frequency_table)
|
||||||
|
|
||||||
stage = AE.process_stage(probability_table, stage_min, stage_max)
|
enc.write(frequency_table, byte)
|
||||||
stage_min, stage_max = stage[byte]
|
|
||||||
|
|
||||||
context.append(byte)
|
context.append(byte)
|
||||||
|
|
||||||
print("Getting encoded value")
|
# print("Getting encoded value")
|
||||||
interval_min, interval_max, _ = AE.get_encoded_value(stage)
|
# interval_min, interval_max, _ = AE.get_encoded_value(stage)
|
||||||
print("Encoding in binary")
|
# print("Encoding in binary")
|
||||||
binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
|
# binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
|
||||||
|
|
||||||
# Pack
|
# Pack
|
||||||
val = int(binary_code, 2) if len(binary_code) else 0
|
# val = int(binary_code, 2) if len(binary_code) else 0
|
||||||
out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
|
# out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
|
||||||
|
|
||||||
if output_file:
|
# if output_file:
|
||||||
print(f"Writing to {output_file}")
|
# print(f"Writing to {output_file}")
|
||||||
with open(output_file, "w") as file:
|
# with open(output_file, "w") as file:
|
||||||
file.write(f"{len(byte_data)}\n")
|
# file.write(f"{len(byte_data)}\n")
|
||||||
file.write(binary_code) # todo: temporary, decoding depends on binary string
|
# file.write(binary_code) # todo: temporary, decoding depends on binary string
|
||||||
else:
|
# else:
|
||||||
print(out_bytes)
|
# print(out_bytes)
|
||||||
|
|
||||||
|
|
||||||
def bits_to_number(bits: str) -> float:
|
def bits_to_number(bits: str) -> float:
|
||||||
|
|
|
||||||
Reference in a new issue