# Review notes for this patch. These lines sit ahead of the first "diff --git"
# header and are not part of any hunk.
#
# NOTE(review): the hunks below are run together on two physical lines rather
# than one diff line per line, so the indentation and blank lines of the Python
# being patched cannot be recovered from this text alone. The line structure
# presumably has to be restored before a patch tool can apply it - confirm.
#
# src/process.py
#   * compress(): the call AE.encode_binary(interval_min, interval_max) becomes
#     AE.custom_binary_encoding(interval_min, interval_max). The result is
#     still unpacked as `binary_code, _`, and the following context line still
#     evaluates binary_code.split(".", maxsplit=1)[1].
#   * decompress(): the zero-argument stub, which returned (did not raise) a
#     NotImplementedError instance, is replaced by
#     decompress(device, model_path, input_file, output_file=None).
#     The added body reads input_file as bytes, prints a message and returns
#     early when the file is empty, loads the model with
#     torch.load(model_path, weights_only=False), moves it to `device`, calls
#     model.eval(), then creates ArithmeticEncoding(frequency_table={0: 1}),
#     the Decimal bounds stage_min/stage_max = 0/1 and a deque of 128 zeros
#     with maxlen 128.
#     The added lines stop there: within this patch nothing is decoded,
#     nothing is written, and `output`, `output_file`, `bytes_data`, `ae`,
#     `stage_min`, `stage_max` and `context` are not used after being set.
#   * NOTE(review): weights_only=False lets torch.load unpickle arbitrary
#     objects - confirm model_path only ever points at trusted files.
#
# src/utils/custom_ae.py
#   * Hunk inside class CustomArithmeticEncoding (the method's `def` line is
#     outside the hunk): the `halves` table, the counter `k` and the loop
#     `while i < 1024` are removed. The added code starts with found = False,
#     next_n = 0.5, n = 0 and loops `while not found`. Each pass tests
#     n + next_n < float_interval_max; the kept context line appends 1 and the
#     added line adds next_n to n; `found` is set when n >= float_interval_min;
#     an `else` branch appends 0; next_n is halved.
#     NOTE(review): which `if` the `else` pairs with cannot be read from the
#     collapsed text - presumably the float_interval_max test; confirm.
#   * The return value changes from the 2-tuple
#     ("0." + ''.join(map(str, code)), k) to the bare string
#     ''.join(map(str, code)), with no "0." prefix and no second element.
#   * NOTE(review): if this hunk is the body of custom_binary_encoding, the
#     compress() call site above no longer matches it: `binary_code, _ = ...`
#     expects two values and split(".", maxsplit=1)[1] expects a "." in the
#     string. Confirm which method is patched and reconcile the two sides.
#   * NOTE(review): the removed loop was capped at 1024 iterations; the only
#     exit of the added loop is found = True. Confirm it terminates for inputs
#     such as float_interval_min >= float_interval_max.
#   * __main__ demo: the sample bounds change from low = 0.25 / high = 0.5 to
#     low = 0.00324 / high = 0.357.
diff --git a/src/process.py b/src/process.py index 166644a..d77f30d 100644 --- a/src/process.py +++ b/src/process.py @@ -62,7 +62,7 @@ def compress( print("Getting encoded value") interval_min, interval_max, _ = AE.get_encoded_value(stage) print("Encoding in binary") - binary_code, _ = AE.encode_binary(interval_min, interval_max) + binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) # Pack bits = binary_code.split(".", maxsplit=1)[1] @@ -77,5 +77,31 @@ def compress( print(out_bytes) -def decompress(): - return NotImplementedError("Decompression is not implemented yet") +def decompress( + device, + model_path: str, + input_file: str, + output_file: str | None = None +): + context_length = 128 + output = bytearray() + + print("Reading in the data") + with open(input_file, "rb") as f: + bytes_data = f.read() + + if len(bytes_data) == 0: + print("Input file is empty, nothing has to be done...") + return + + print("Loading the model") + model = torch.load(model_path, weights_only=False) + model.to(device) + model.eval() + + print("Decompressing") + ae = ArithmeticEncoding(frequency_table={0: 1}) + stage_min, stage_max = Decimal(0), Decimal(1) + + context = deque([0] * context_length, maxlen=context_length) + diff --git a/src/utils/custom_ae.py b/src/utils/custom_ae.py index 96e1242..006e53b 100644 --- a/src/utils/custom_ae.py +++ b/src/utils/custom_ae.py @@ -201,40 +201,23 @@ class CustomArithmeticEncoding: float_interval_max: float """ code = [] - k = 1 - halves = [ - [0.0, 1 / 2], - [1 / 2, 1.0] - ] + found = False + next_n = 0.5 + n = 0 - i = 0 - - while i < 1024: - k += 1 - i += 1 - - if halves[0][0] >= float_interval_min and halves[0][1] < float_interval_max: - break - if halves[1][0] >= float_interval_min and halves[1][1] < float_interval_max: - break - - # left interval, insert 0 - if float_interval_max < halves[0][1]: - code.append(0) - low = halves[0][0] - high = halves[0][1] - - else: + while not found: + if n + next_n < float_interval_max: 
code.append(1) - low = halves[1][0] - high = halves[1][1] + n += next_n - halves[0][0] = low - halves[0][1] = low + 1 / (1 << k) - halves[1][0] = halves[0][1] - halves[1][1] = high + if n >= float_interval_min: + found = True + else: + code.append(0) - return "0." + ''.join(map(str, code)), k + next_n /= 2 + + return ''.join(map(str, code)) def decode(self, encoded_msg, msg_length, probability_table): """ @@ -358,8 +341,8 @@ def bin2float(bin_num): if __name__ == "__main__": coder = CustomArithmeticEncoding({}) - low = 0.25 - high = 0.5 + low = 0.00324 + high = 0.357 # slow_code = coder.encode_binary(low, high) fast_code = coder.custom_binary_encoding(low, high)