From 5c26a52e1612bda9e058a9cb40f618f33064c151 Mon Sep 17 00:00:00 2001
From: Tibo De Peuter
Date: Wed, 10 Dec 2025 21:13:09 +0100
Subject: [PATCH] feat (WIP): Compress

---
 pyproject.toml |  4 ++++
 src/process.py | 65 ++++++++++++++++++++++++++++++++++++++++++++------
 src/train.py   |  2 +-
 uv.lock        |  7 ++++++
 4 files changed, 70 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index fa21be3..b100b4f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
     "huggingface_hub==0.27.0",
     "fsspec==2024.9.0",
     "lorem>=0.1.1",
+    "arithmeticencodingpython",
 ]
 
 [project.optional-dependencies]
@@ -21,3 +22,6 @@ dev = [
     "torchdata==0.7.1",
     "torchvision==0.24.0",
 ]
+
+[tool.uv.sources]
+arithmeticencodingpython = { git = "https://github.com/ahmedfgad/ArithmeticEncodingPython.git", rev = "60aad0528c57289218b241d75993574f31b90456" }
diff --git a/src/process.py b/src/process.py
index b2edda3..166644a 100644
--- a/src/process.py
+++ b/src/process.py
@@ -1,13 +1,22 @@
+from collections import deque
+from decimal import Decimal
+
 import torch
+from pyae import ArithmeticEncoding
+from tqdm import tqdm
 
 
 def compress(
-    device,
-    model_path: str,
-    output_file: str,
-    input_file: str | None = None
+        device,
+        model_path: str,
+        input_file: str | None = None,
+        output_file: str | None = None
 ):
+    # NOTE Hardcoded context length
+    context_length = 128
+
     # Get input to compress
+    print("Reading input")
     if input_file:
         with open(input_file, "rb") as file:
             byte_data = file.read()
@@ -16,14 +25,56 @@ def compress(
         text = input()
         byte_data = text.encode('utf-8', errors='replace')
 
+    print("Converting to tensor")
     tensor = torch.tensor(list(byte_data), dtype=torch.long)
-    print(tensor)
 
     # Get model
+    print("Loading model")
     model = torch.load(model_path, weights_only=False)
+    model.to(device)
+    model.eval()
 
-    # TODO Feed to model for compression, store result
-    return
+    # Init AE
+    print("Initializing AE")
+    AE = ArithmeticEncoding(frequency_table={0: 1})  # These are dummies because they are not used
+    stage_min, stage_max = Decimal(0), Decimal(1)
+    stage = None
+
+    # Compress
+    context = deque([0] * context_length, maxlen=context_length)
+    for byte in tqdm(tensor.tolist(), desc="Compressing"):
+        context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
+
+        with torch.inference_mode():
+            logits = model(context_tensor)
+            probabilities = torch.softmax(logits[0], dim=-1)
+            probabilities = probabilities.detach().cpu().numpy()
+
+        eps = 1e-10
+        frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
+        probability_table = AE.get_probability_table(frequency_table)
+
+        stage = AE.process_stage(probability_table, stage_min, stage_max)
+        stage_min, stage_max = stage[byte]
+
+        context.append(byte)
+
+    print("Getting encoded value")
+    interval_min, interval_max, _ = AE.get_encoded_value(stage)
+    print("Encoding in binary")
+    binary_code, _ = AE.encode_binary(interval_min, interval_max)
+
+    # Pack
+    bits = binary_code.split(".", maxsplit=1)[1]
+    val = int(bits, 2) if len(bits) else 0
+    out_bytes = val.to_bytes((len(bits) + 7) // 8, "big")
+
+    if output_file:
+        print(f"Writing to {output_file}")
+        with open(output_file, "wb") as file:
+            file.write(out_bytes)
+    else:
+        print(out_bytes)
 
 
 def decompress():
diff --git a/src/train.py b/src/train.py
index ee4a99a..f359fba 100644
--- a/src/train.py
+++ b/src/train.py
@@ -19,7 +19,7 @@ def train(
     model_path: str | None = None,
     model_out: str | None = None
 ):
-    batch_size = 2
+    batch_size = 64
 
     assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
 
diff --git a/uv.lock b/uv.lock
index bf27f7f..24dafc2 100644
--- a/uv.lock
+++ b/uv.lock
@@ -163,6 +163,11 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" },
 ]
 
+[[package]]
+name = "arithmeticencodingpython"
+version = "1.0.0"
+source = { git = "https://github.com/ahmedfgad/ArithmeticEncodingPython.git?rev=60aad0528c57289218b241d75993574f31b90456#60aad0528c57289218b241d75993574f31b90456" }
+
 [[package]]
 name = "attrs"
 version = "25.4.0"
@@ -1621,6 +1626,7 @@ name = "project-ml"
 version = "0.1.0"
 source = { virtual = "." }
 dependencies = [
+    { name = "arithmeticencodingpython" },
     { name = "datasets" },
     { name = "fsspec" },
     { name = "huggingface-hub" },
@@ -1640,6 +1646,7 @@ dev = [
 
 [package.metadata]
 requires-dist = [
+    { name = "arithmeticencodingpython", git = "https://github.com/ahmedfgad/ArithmeticEncodingPython.git?rev=60aad0528c57289218b241d75993574f31b90456" },
    { name = "datasets", specifier = ">=3.2.0" },
     { name = "fsspec", specifier = "==2024.9.0" },
     { name = "huggingface-hub", specifier = "==0.27.0" },
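
Reviewer note, not part of the patch: a minimal sketch of how the new compress() entry point might be invoked, assuming src/ is importable as a package and that train() has already produced a checkpoint. The module path, file names, and device selection below are illustrative assumptions, not anything established by this change.

# Hypothetical driver for the compress() added in this patch.
# Assumptions: src/ is on the import path, "model.pt" was written by train(),
# and "input.txt" / "input.ae" are arbitrary example file names.
import torch

from src.process import compress

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
compress(
    device,
    model_path="model.pt",
    input_file="input.txt",
    output_file="input.ae",
)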