From 5c26a52e1612bda9e058a9cb40f618f33064c151 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Wed, 10 Dec 2025 21:13:09 +0100 Subject: [PATCH 01/11] feat (WIP): Compress --- pyproject.toml | 4 ++++ src/process.py | 65 ++++++++++++++++++++++++++++++++++++++++++++------ src/train.py | 2 +- uv.lock | 7 ++++++ 4 files changed, 70 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fa21be3..b100b4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "huggingface_hub==0.27.0", "fsspec==2024.9.0", "lorem>=0.1.1", + "arithmeticencodingpython", ] [project.optional-dependencies] @@ -21,3 +22,6 @@ dev = [ "torchdata==0.7.1", "torchvision==0.24.0", ] + +[tool.uv.sources] +arithmeticencodingpython = { git = "https://github.com/ahmedfgad/ArithmeticEncodingPython.git", rev = "60aad0528c57289218b241d75993574f31b90456" } diff --git a/src/process.py b/src/process.py index b2edda3..166644a 100644 --- a/src/process.py +++ b/src/process.py @@ -1,13 +1,22 @@ +from collections import deque +from decimal import Decimal + import torch +from pyae import ArithmeticEncoding +from tqdm import tqdm def compress( - device, - model_path: str, - output_file: str, - input_file: str | None = None + device, + model_path: str, + input_file: str | None = None, + output_file: str | None = None ): + # NOTE Hardcoded context length + context_length = 128 + # Get input to compress + print("Reading input") if input_file: with open(input_file, "rb") as file: byte_data = file.read() @@ -16,14 +25,56 @@ def compress( text = input() byte_data = text.encode('utf-8', errors='replace') + print("Converting to tensor") tensor = torch.tensor(list(byte_data), dtype=torch.long) - print(tensor) # Get model + print("Loading model") model = torch.load(model_path, weights_only=False) + model.to(device) + model.eval() - # TODO Feed to model for compression, store result - return + # Init AE + print("Initializing AE") + AE = ArithmeticEncoding(frequency_table={0: 1}) # These are dummies because they are not used + stage_min, stage_max = Decimal(0), Decimal(1) + stage = None + + # Compress + context = deque([0] * context_length, maxlen=context_length) + for byte in tqdm(tensor.tolist(), desc="Compressing"): + context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device) + + with torch.inference_mode(): + logits = model(context_tensor) + probabilities = torch.softmax(logits[0], dim=-1) + probabilities = probabilities.detach().cpu().numpy() + + eps = 1e-10 + frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} + probability_table = AE.get_probability_table(frequency_table) + + stage = AE.process_stage(probability_table, stage_min, stage_max) + stage_min, stage_max = stage[byte] + + context.append(byte) + + print("Getting encoded value") + interval_min, interval_max, _ = AE.get_encoded_value(stage) + print("Encoding in binary") + binary_code, _ = AE.encode_binary(interval_min, interval_max) + + # Pack + bits = binary_code.split(".", maxsplit=1)[1] + val = int(bits, 2) if len(bits) else 0 + out_bytes = val.to_bytes((len(bits) + 7) // 8, "big") + + if output_file: + print(f"Writing to {output_file}") + with open(output_file, "wb") as file: + file.write(out_bytes) + else: + print(out_bytes) def decompress(): diff --git a/src/train.py b/src/train.py index ee4a99a..f359fba 100644 --- a/src/train.py +++ b/src/train.py @@ -19,7 +19,7 @@ def train( model_path: str | None = None, model_out: str | None = None ): - batch_size = 2 + batch_size = 64 
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided" diff --git a/uv.lock b/uv.lock index bf27f7f..24dafc2 100644 --- a/uv.lock +++ b/uv.lock @@ -163,6 +163,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" }, ] +[[package]] +name = "arithmeticencodingpython" +version = "1.0.0" +source = { git = "https://github.com/ahmedfgad/ArithmeticEncodingPython.git?rev=60aad0528c57289218b241d75993574f31b90456#60aad0528c57289218b241d75993574f31b90456" } + [[package]] name = "attrs" version = "25.4.0" @@ -1621,6 +1626,7 @@ name = "project-ml" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "arithmeticencodingpython" }, { name = "datasets" }, { name = "fsspec" }, { name = "huggingface-hub" }, @@ -1640,6 +1646,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "arithmeticencodingpython", git = "https://github.com/ahmedfgad/ArithmeticEncodingPython.git?rev=60aad0528c57289218b241d75993574f31b90456" }, { name = "datasets", specifier = ">=3.2.0" }, { name = "fsspec", specifier = "==2024.9.0" }, { name = "huggingface-hub", specifier = "==0.27.0" }, From 5de81819593ef830e986dde5651f5b5300f739d7 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Thu, 11 Dec 2025 14:41:20 +0100 Subject: [PATCH 02/11] chore: Add Nix configs --- config/configuration.nix | 207 +++++++++++++++++++++++++++++++++++++++ config/flake.lock | 151 ++++++++++++++++++++++++++++ config/flake.nix | 66 +++++++++++++ 3 files changed, 424 insertions(+) create mode 100644 config/configuration.nix create mode 100644 config/flake.lock create mode 100644 config/flake.nix diff --git a/config/configuration.nix b/config/configuration.nix new file mode 100644 index 0000000..649767d --- /dev/null +++ b/config/configuration.nix @@ -0,0 +1,207 @@ +# Edit this configuration file to define what should be installed on +# your system. Help is available in the configuration.nix(5) man page, on +# https://search.nixos.org/options and in the NixOS manual (`nixos-help`). + +{ config, lib, pkgs, ... }: + +{ + imports = + [ # Include the results of the hardware scan. + ./hardware-configuration.nix + ]; + + # Use the systemd-boot EFI boot loader. + boot.loader = { + systemd-boot.enable = true; + efi = { + efiSysMountPoint = "/boot/efi"; + canTouchEfiVariables = true; + }; + }; + + networking.hostName = "MachineLearning"; # Define your hostname. + # Pick only one of the below networking options. + # networking.wireless.enable = true; # Enables wireless support via wpa_supplicant. + # networking.networkmanager.enable = true; # Easiest to use and most distros use this by default. + + # Set your time zone. + time.timeZone = "Europe/Brussels"; + + # Configure network proxy if necessary + # networking.proxy.default = "http://user:password@proxy:port/"; + # networking.proxy.noProxy = "127.0.0.1,localhost,internal.domain"; + + # Select internationalisation properties. + # i18n.defaultLocale = "en_US.UTF-8"; + # console = { + # font = "Lat2-Terminus16"; + # keyMap = "us"; + # useXkbConfig = true; # use xkb.options in tty. + # }; + + # Enable the X11 windowing system. 
+ services.xserver = { + #enable = true; + videoDrivers = [ + "nvidia" + ]; + }; + + # Configure keymap in X11 + # services.xserver.xkb.layout = "us"; + # services.xserver.xkb.options = "eurosign:e,caps:escape"; + + # Enable CUPS to print documents. + # services.printing.enable = true; + + # Enable sound. + # services.pulseaudio.enable = true; + # OR + # services.pipewire = { + # enable = true; + # pulse.enable = true; + # }; + + # Enable touchpad support (enabled default in most desktopManager). + # services.libinput.enable = true; + + # Define a user account. Don't forget to set a password with ‘passwd’. + # users.users.alice = { + # isNormalUser = true; + # extraGroups = [ "wheel" ]; # Enable ‘sudo’ for the user. + # packages = with pkgs; [ + # tree + # ]; + # }; + users.users = { + admin = { + description = "System Administrator"; + isNormalUser = true; + extraGroups = [ + config.users.groups.wheel.name # Enable 'sudo' for the user. + ]; + initialPassword = "ChangeMe"; + + openssh.authorizedKeys.keys = [ + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFdkZTYhBdUJ1YXx/2Iek0XC/jkbdxg37GORpXUgP2NO" + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGNSav7u6OxtxlAzq170/HuzE8cGvCULVGAiragtS5T6" + ]; + }; + + ml = { + description = "Machine Learning benchmarks"; + isNormalUser = true; + + openssh.authorizedKeys.keys = [ + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFdkZTYhBdUJ1YXx/2Iek0XC/jkbdxg37GORpXUgP2NO" + "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGNSav7u6OxtxlAzq170/HuzE8cGvCULVGAiragtS5T6" + ]; + }; + }; + + # programs.firefox.enable = true; + + # List packages installed in system profile. + # You can use https://search.nixos.org/ to find more packages (and options). + environment.systemPackages = with pkgs; [ + vim + curl + git + wget + tmux + ]; + + hardware = { + graphics = { + enable = true; + enable32Bit = true; + extraPackages = with pkgs; [ + intel-ocl + intel-compute-runtime + intel-graphics-compiler + opencl-clhpp + opencl-headers + ocl-icd + ]; + }; + nvidia = { + modesetting.enable = true; + powerManagement.enable = false; + powerManagement.finegrained = false; + open = false; + nvidiaSettings = false; + package = config.boot.kernelPackages.nvidiaPackages.stable; + +# prime = { +# nvidiaBusId = "PCI:1:0:0"; +# intelBusId = "PCI:0:2:0"; +# }; + }; + }; + + # Some programs need SUID wrappers, can be configured further or are + # started in user sessions. + # programs.mtr.enable = true; + # programs.gnupg.agent = { + # enable = true; + # enableSSHSupport = true; + # }; + + nix.settings = { + substituters = [ + "https://cache.nixos-cuda.org" + ]; + trusted-public-keys = [ + "cache.nixos-cuda.org:74DUi4Ye579gUqzH4ziL9IyiJBlDpMRn9MBN8oNan9M=" + ]; + experimental-features = [ + "nix-command" + "flakes" + ]; + }; + + nixpkgs.config.allowUnfree = true; + + # List services that you want to enable: + + # Enable the OpenSSH daemon. + services.openssh = { + enable = true; + settings = { + PasswordAuthentication = false; + PermitRootLogin = "no"; + }; + }; + + # Open ports in the firewall. + # networking.firewall.allowedTCPPorts = [ ... ]; + # networking.firewall.allowedUDPPorts = [ ... ]; + # Or disable the firewall altogether. + # networking.firewall.enable = false; + + # Copy the NixOS configuration file and link it from the resulting system + # (/run/current-system/configuration.nix). This is useful in case you + # accidentally delete configuration.nix. 
+ # system.copySystemConfiguration = true; + + # This option defines the first version of NixOS you have installed on this particular machine, + # and is used to maintain compatibility with application data (e.g. databases) created on older NixOS versions. + # + # Most users should NEVER change this value after the initial install, for any reason, + # even if you've upgraded your system to a new NixOS release. + # + # This value does NOT affect the Nixpkgs version your packages and OS are pulled from, + # so changing it will NOT upgrade your system - see https://nixos.org/manual/nixos/stable/#sec-upgrading for how + # to actually do that. + # + # This value being lower than the current NixOS release does NOT mean your system is + # out of date, out of support, or vulnerable. + # + # Do NOT change this value unless you have manually inspected all the changes it would make to your configuration, + # and migrated your data accordingly. + # + # For more information, see `man configuration.nix` or https://nixos.org/manual/nixos/stable/options#opt-system.stateVersion . + system.stateVersion = "25.05"; # Did you read the comment? + +} + diff --git a/config/flake.lock b/config/flake.lock new file mode 100644 index 0000000..16f7df5 --- /dev/null +++ b/config/flake.lock @@ -0,0 +1,151 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": [ + "nix-jetbrains-plugins", + "systems" + ] + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nix-jetbrains-plugins": { + "inputs": { + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs", + "systems": "systems_2" + }, + "locked": { + "lastModified": 1765025946, + "narHash": "sha256-ZSeAc3h08Lv67gbUjDMK6GTrQgYsrNpFNJEavCPxN8I=", + "owner": "theCapypara", + "repo": "nix-jetbrains-plugins", + "rev": "b861755ca1f4f7633ffdddc5608c32632cecebc3", + "type": "github" + }, + "original": { + "owner": "theCapypara", + "repo": "nix-jetbrains-plugins", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1757745802, + "narHash": "sha256-hLEO2TPj55KcUFUU1vgtHE9UEIOjRcH/4QbmfHNF820=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c23193b943c6c689d70ee98ce3128239ed9e32d1", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-unstable": { + "locked": { + "lastModified": 1765186076, + "narHash": "sha256-hM20uyap1a0M9d344I692r+ik4gTMyj60cQWO+hAYP8=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "addf7cf5f383a3101ecfba091b98d0a1263dc9b8", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1764939437, + "narHash": "sha256-4TLFHUwXraw9Df5mXC/vCrJgb50CRr3CzUzF0Mn3CII=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": 
"00d2457e2f608b4be6fe8b470b0a36816324b0ae", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-25.05", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nix-jetbrains-plugins": "nix-jetbrains-plugins", + "nixpkgs": "nixpkgs_2", + "nixpkgs-unstable": "nixpkgs-unstable" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/config/flake.nix b/config/flake.nix new file mode 100644 index 0000000..da326a6 --- /dev/null +++ b/config/flake.nix @@ -0,0 +1,66 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05"; + nixpkgs-unstable.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + nix-jetbrains-plugins.url = "github:theCapypara/nix-jetbrains-plugins"; + }; + + outputs = { self, nixpkgs, nixpkgs-unstable, flake-utils, nix-jetbrains-plugins }: + flake-utils.lib.eachDefaultSystem (system: let + pkgs = import nixpkgs { + inherit system; + config.allowUnfree = true; + }; + pkgs-unstable = import nixpkgs-unstable { + inherit system; + config.allowUnfree = true; + }; + + python-packages = p: with p; [ + numpy + ]; + + pluginList = [ + "be.ugent.piedcler.dodona" + "com.github.copilot" + "com.google.tools.ij.aiplugin" + "IdeaVIM" + ]; + + mkShell = pkgs.mkShell.override { + stdenv = pkgs.stdenvAdapters.useMoldLinker pkgs.stdenv; + }; + in { + devShells.default = pkgs.mkShell { + packages = (with pkgs; [ + python311 + (python-packages python311Packages) + + # CUDA + git gitRepo gnupg autoconf curl + procps gnumake util-linux m4 gperf unzip + cudatoolkit linuxPackages.nvidia_x11 + libGLU libGL + xorg.libXi xorg.libXmu freeglut + xorg.libXext xorg.libX11 xorg.libXv xorg.libXrandr zlib + ncurses5 stdenv.cc binutils + ]) ++ (with pkgs-unstable; [ + uv + ]) ++ (with nix-jetbrains-plugins.lib."${system}"; [ + # Editor of your choice + #(buildIdeWithPlugins pkgs-unstable.jetbrains "pycharm-professional" pluginList) + ]); + + # CUDA + CUDA_PATH = pkgs.cudatoolkit; + # ImportError: libstdc++.so.6: cannot open shared object file: No such file or directory + LD_LIBRARY_PATH = "${pkgs.linuxPackages.nvidia_x11}/lib:${pkgs.ncurses5}/lib:${pkgs.libGL}/lib/:${pkgs.stdenv.cc.cc.lib}/lib/:${pkgs.glibc}/lib"; + EXTRA_LDFLAGS = "-L/lib -L${pkgs.linuxPackages.nvidia_x11}/lib"; + EXTRA_CCFLAGS = "-I/usr/include"; + + # Stop uv from downloading Python binaries automatically if needed. 
+ UV_PYTHON_DOWNLOADS = "never"; + }; + }); +} From 653b44804ad3059fcef03aba8c51ad52723c7db2 Mon Sep 17 00:00:00 2001 From: Robin Meersman Date: Thu, 11 Dec 2025 15:22:28 +0100 Subject: [PATCH 03/11] feat: faster encoding binary --- src/utils/custom_ae.py | 368 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 368 insertions(+) create mode 100644 src/utils/custom_ae.py diff --git a/src/utils/custom_ae.py b/src/utils/custom_ae.py new file mode 100644 index 0000000..96e1242 --- /dev/null +++ b/src/utils/custom_ae.py @@ -0,0 +1,368 @@ +from decimal import Decimal + + +class CustomArithmeticEncoding: + """ + ArithmeticEncoding is a class for building the arithmetic encoding. + """ + + def __init__(self, frequency_table, save_stages=False): + """ + frequency_table: Frequency table as a dictionary where key is the symbol and value is the frequency. + save_stages: If True, then the intervals of each stage are saved in a list. Note that setting save_stages=True may cause memory overflow if the message is large + """ + + self.save_stages = save_stages + if (save_stages == True): + print("WARNING: Setting save_stages=True may cause memory overflow if the message is large.") + + self.probability_table = self.get_probability_table(frequency_table) + + def get_probability_table(self, frequency_table): + """ + Calculates the probability table out of the frequency table. + + frequency_table: A table of the term frequencies. + + Returns the probability table. + """ + total_frequency = sum(list(frequency_table.values())) + + probability_table = {} + for key, value in frequency_table.items(): + probability_table[key] = value / total_frequency + + return probability_table + + def get_encoded_value(self, last_stage_probs): + """ + After encoding the entire message, this method returns the single value that represents the entire message. + + last_stage_probs: A list of the probabilities in the last stage. + + Returns the minimum and maximum probabilites in the last stage in addition to the value encoding the message. + """ + last_stage_probs = list(last_stage_probs.values()) + last_stage_values = [] + for sublist in last_stage_probs: + for element in sublist: + last_stage_values.append(element) + + last_stage_min = min(last_stage_values) + last_stage_max = max(last_stage_values) + encoded_value = (last_stage_min + last_stage_max) / 2 + + return last_stage_min, last_stage_max, encoded_value + + def process_stage(self, probability_table, stage_min, stage_max): + """ + Processing a stage in the encoding/decoding process. + + probability_table: The probability table. + stage_min: The minumim probability of the current stage. + stage_max: The maximum probability of the current stage. + + Returns the probabilities in the stage. + """ + + stage_probs = {} + stage_domain = stage_max - stage_min + for term_idx in range(len(probability_table.items())): + term = list(probability_table.keys())[term_idx] + term_prob = Decimal(probability_table[term]) + cum_prob = term_prob * stage_domain + stage_min + stage_probs[term] = [stage_min, cum_prob] + stage_min = cum_prob + return stage_probs + + def encode(self, msg, probability_table): + """ + Encodes a message using arithmetic encoding. + + msg: The message to be encoded. + probability_table: The probability table. + + Returns the encoder, the floating-point value representing the encoded message, and the maximum and minimum values of the interval in which the floating-point value falls. 
+ """ + + msg = list(msg) + + encoder = [] + + stage_min = Decimal(0.0) + stage_max = Decimal(1.0) + + for msg_term_idx in range(len(msg)): + stage_probs = self.process_stage(probability_table, stage_min, stage_max) + + msg_term = msg[msg_term_idx] + stage_min = stage_probs[msg_term][0] + stage_max = stage_probs[msg_term][1] + + if self.save_stages: + encoder.append(stage_probs) + + last_stage_probs = self.process_stage(probability_table, stage_min, stage_max) + + if self.save_stages: + encoder.append(last_stage_probs) + + interval_min_value, interval_max_value, encoded_msg = self.get_encoded_value(last_stage_probs) + + return encoded_msg, encoder, interval_min_value, interval_max_value + + def process_stage_binary(self, float_interval_min, float_interval_max, stage_min_bin, stage_max_bin): + """ + Processing a stage in the encoding/decoding process. + + float_interval_min: The minimum floating-point value in the interval in which the floating-point value that encodes the message is located. + float_interval_max: The maximum floating-point value in the interval in which the floating-point value that encodes the message is located. + stage_min_bin: The minimum binary number in the current stage. + stage_max_bin: The maximum binary number in the current stage. + + Returns the probabilities of the terms in this stage. There are only 2 terms. + """ + + stage_mid_bin = stage_min_bin + "1" + stage_min_bin = stage_min_bin + "0" + + stage_probs = {} + stage_probs[0] = [stage_min_bin, stage_mid_bin] + stage_probs[1] = [stage_mid_bin, stage_max_bin] + + return stage_probs + + def encode_binary(self, float_interval_min, float_interval_max): + """ + Calculates the binary code that represents the floating-point value that encodes the message. + + float_interval_min: The minimum floating-point value in the interval in which the floating-point value that encodes the message is located. + float_interval_max: The maximum floating-point value in the interval in which the floating-point value that encodes the message is located. + + Returns the binary code representing the encoded message. + """ + + binary_encoder = [] + binary_code = None + + stage_min_bin = "0.0" + stage_max_bin = "1.0" + + stage_probs = {} + stage_probs[0] = [stage_min_bin, "0.1"] + stage_probs[1] = ["0.1", stage_max_bin] + + while True: + if float_interval_max < bin2float(stage_probs[0][1]): + stage_min_bin = stage_probs[0][0] + stage_max_bin = stage_probs[0][1] + else: + stage_min_bin = stage_probs[1][0] + stage_max_bin = stage_probs[1][1] + + if self.save_stages: + binary_encoder.append(stage_probs) + + stage_probs = self.process_stage_binary(float_interval_min, + float_interval_max, + stage_min_bin, + stage_max_bin) + + # print(stage_probs[0][0], bin2float(stage_probs[0][0])) + # print(stage_probs[0][1], bin2float(stage_probs[0][1])) + if (bin2float(stage_probs[0][0]) >= float_interval_min) and ( + bin2float(stage_probs[0][1]) < float_interval_max): + # The binary code is found. + # print(stage_probs[0][0], bin2float(stage_probs[0][0])) + # print(stage_probs[0][1], bin2float(stage_probs[0][1])) + # print("The binary code is : ", stage_probs[0][0]) + binary_code = stage_probs[0][0] + break + elif (bin2float(stage_probs[1][0]) >= float_interval_min) and ( + bin2float(stage_probs[1][1]) < float_interval_max): + # The binary code is found. 
+ # print(stage_probs[1][0], bin2float(stage_probs[1][0])) + # print(stage_probs[1][1], bin2float(stage_probs[1][1])) + # print("The binary code is : ", stage_probs[1][0]) + binary_code = stage_probs[1][0] + break + + if self.save_stages: + binary_encoder.append(stage_probs) + + return binary_code, binary_encoder + + def custom_binary_encoding(self, float_interval_min, float_interval_max): + """ + Find the binary representation of the floating punt number which lies in + [float_interval_min, float_interval_max). + + float_interval_min: float + float_interval_max: float + """ + code = [] + k = 1 + halves = [ + [0.0, 1 / 2], + [1 / 2, 1.0] + ] + + i = 0 + + while i < 1024: + k += 1 + i += 1 + + if halves[0][0] >= float_interval_min and halves[0][1] < float_interval_max: + break + if halves[1][0] >= float_interval_min and halves[1][1] < float_interval_max: + break + + # left interval, insert 0 + if float_interval_max < halves[0][1]: + code.append(0) + low = halves[0][0] + high = halves[0][1] + + else: + code.append(1) + low = halves[1][0] + high = halves[1][1] + + halves[0][0] = low + halves[0][1] = low + 1 / (1 << k) + halves[1][0] = halves[0][1] + halves[1][1] = high + + return "0." + ''.join(map(str, code)), k + + def decode(self, encoded_msg, msg_length, probability_table): + """ + Decodes a message from a floating-point number. + + encoded_msg: The floating-point value that encodes the message. + msg_length: Length of the message. + probability_table: The probability table. + + Returns the decoded message. + """ + + decoder = [] + + decoded_msg = [] + + stage_min = Decimal(0.0) + stage_max = Decimal(1.0) + + for idx in range(msg_length): + stage_probs = self.process_stage(probability_table, stage_min, stage_max) + + for msg_term, value in stage_probs.items(): + if encoded_msg >= value[0] and encoded_msg <= value[1]: + break + + decoded_msg.append(msg_term) + + stage_min = stage_probs[msg_term][0] + stage_max = stage_probs[msg_term][1] + + if self.save_stages: + decoder.append(stage_probs) + + if self.save_stages: + last_stage_probs = self.process_stage(probability_table, stage_min, stage_max) + decoder.append(last_stage_probs) + + return decoded_msg, decoder + + +def float2bin(float_num, num_bits=None): + """ + Converts a floating-point number into binary. + + float_num: The floating-point number. + num_bits: The number of bits expected in the result. If None, then the number of bits depends on the number. + + Returns the binary representation of the number. + """ + + float_num = str(float_num) + if float_num.find(".") == -1: + # No decimals in the floating-point number. + integers = float_num + decimals = "" + else: + integers, decimals = float_num.split(".") + decimals = "0." + decimals + decimals = Decimal(decimals) + integers = int(integers) + + result = "" + num_used_bits = 0 + while True: + mul = decimals * 2 + int_part = int(mul) + result = result + str(int_part) + num_used_bits = num_used_bits + 1 + + decimals = mul - int(mul) + if type(num_bits) is type(None): + if decimals == 0: + break + elif num_used_bits >= num_bits: + break + if type(num_bits) is type(None): + pass + elif len(result) < num_bits: + num_remaining_bits = num_bits - len(result) + result = result + "0" * num_remaining_bits + + integers_bin = bin(integers)[2:] + result = str(integers_bin) + "." + str(result) + return result + + +def bin2float(bin_num): + """ + Converts a binary number to a floating-point number. + + bin_num: The binary number as a string. + + Returns the floating-point representation. 
+ """ + + if bin_num.find(".") == -1: + # No decimals in the binary number. + integers = bin_num + decimals = "" + else: + integers, decimals = bin_num.split(".") + result = Decimal(0.0) + + # Working with integers. + for idx, bit in enumerate(integers): + if bit == "0": + continue + mul = 2 ** idx + result = result + Decimal(mul) + + # Working with decimals. + for idx, bit in enumerate(decimals): + if bit == "0": + continue + mul = Decimal(1.0) / Decimal((2 ** (idx + 1))) + result = result + mul + return result + + +if __name__ == "__main__": + coder = CustomArithmeticEncoding({}) + + low = 0.25 + high = 0.5 + + # slow_code = coder.encode_binary(low, high) + fast_code = coder.custom_binary_encoding(low, high) + + # print(slow_code) + print(fast_code) From 77b80914e8bcb8e316258b7ca04b297484669baa Mon Sep 17 00:00:00 2001 From: Robin Meersman Date: Thu, 11 Dec 2025 20:36:24 +0100 Subject: [PATCH 04/11] fix: encoding binary now SUPAH fast --- src/process.py | 32 +++++++++++++++++++++++++--- src/utils/custom_ae.py | 47 ++++++++++++++---------------------------- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/src/process.py b/src/process.py index 166644a..d77f30d 100644 --- a/src/process.py +++ b/src/process.py @@ -62,7 +62,7 @@ def compress( print("Getting encoded value") interval_min, interval_max, _ = AE.get_encoded_value(stage) print("Encoding in binary") - binary_code, _ = AE.encode_binary(interval_min, interval_max) + binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) # Pack bits = binary_code.split(".", maxsplit=1)[1] @@ -77,5 +77,31 @@ def compress( print(out_bytes) -def decompress(): - return NotImplementedError("Decompression is not implemented yet") +def decompress( + device, + model_path: str, + input_file: str, + output_file: str | None = None +): + context_length = 128 + output = bytearray() + + print("Reading in the data") + with open(input_file, "rb") as f: + bytes_data = f.read() + + if len(bytes_data) == 0: + print("Input file is empty, nothing has to be done...") + return + + print("Loading the model") + model = torch.load(model_path, weights_only=False) + model.to(device) + model.eval() + + print("Decompressing") + ae = ArithmeticEncoding(frequency_table={0: 1}) + stage_min, stage_max = Decimal(0), Decimal(1) + + context = deque([0] * context_length, maxlen=context_length) + diff --git a/src/utils/custom_ae.py b/src/utils/custom_ae.py index 96e1242..006e53b 100644 --- a/src/utils/custom_ae.py +++ b/src/utils/custom_ae.py @@ -201,40 +201,23 @@ class CustomArithmeticEncoding: float_interval_max: float """ code = [] - k = 1 - halves = [ - [0.0, 1 / 2], - [1 / 2, 1.0] - ] + found = False + next_n = 0.5 + n = 0 - i = 0 - - while i < 1024: - k += 1 - i += 1 - - if halves[0][0] >= float_interval_min and halves[0][1] < float_interval_max: - break - if halves[1][0] >= float_interval_min and halves[1][1] < float_interval_max: - break - - # left interval, insert 0 - if float_interval_max < halves[0][1]: - code.append(0) - low = halves[0][0] - high = halves[0][1] - - else: + while not found: + if n + next_n < float_interval_max: code.append(1) - low = halves[1][0] - high = halves[1][1] + n += next_n - halves[0][0] = low - halves[0][1] = low + 1 / (1 << k) - halves[1][0] = halves[0][1] - halves[1][1] = high + if n >= float_interval_min: + found = True + else: + code.append(0) - return "0." 
+ ''.join(map(str, code)), k + next_n /= 2 + + return ''.join(map(str, code)) def decode(self, encoded_msg, msg_length, probability_table): """ @@ -358,8 +341,8 @@ def bin2float(bin_num): if __name__ == "__main__": coder = CustomArithmeticEncoding({}) - low = 0.25 - high = 0.5 + low = 0.00324 + high = 0.357 # slow_code = coder.encode_binary(low, high) fast_code = coder.custom_binary_encoding(low, high) From ff11c1deb38cb1d582ed1e54ca8ae64145af9492 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Thu, 11 Dec 2025 22:21:47 +0100 Subject: [PATCH 05/11] Add testing configs --- README.md | 17 ++++- config/download_datasets.sh | 95 ++++++++++++++++++++++++++ config/generate_csv.sh | 106 +++++++++++++++++++++++++++++ config/local.sh | 27 ++++++++ config/{ => nix}/configuration.nix | 0 config/{ => nix}/flake.lock | 0 config/{ => nix}/flake.nix | 0 config/sub.csv | 5 ++ config/urls.txt | 2 + 9 files changed, 250 insertions(+), 2 deletions(-) create mode 100644 config/download_datasets.sh create mode 100644 config/generate_csv.sh create mode 100644 config/local.sh rename config/{ => nix}/configuration.nix (100%) rename config/{ => nix}/flake.lock (100%) rename config/{ => nix}/flake.nix (100%) create mode 100644 config/sub.csv create mode 100644 config/urls.txt diff --git a/README.md b/README.md index 2b0b5f7..28058f6 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,22 @@ Example usage: ```shell -python main.py --debug train --dataset enwik9 --data-root ~/data/datasets/ml --method optuna --model transformer --model-save-path ~/data/ml-models/test-transformer.pt +# Fetching +python main.py --debug train --method fetch \ + --dataset enwik9 --data-root /path/to/datasets -python benchmark.py --debug train --dataset enwik9 --data-root ~/data/datasets/ml --method optuna --model cnn --model-save-path ~/data/ml-models/test-cnn.pt +# Training +python main.py --debug train --method optuna \ + --dataset enwik9 --data-root /path/to/datasets \ + --model cnn --model-save-path /path/to/optuna-model +python main.py --debug --results /path/to/results train --method full \ + --dataset enwik9 --data-root /path/to/datasets \ + --model-load-path /path/to/optuna-model --model-save-path /path/to/full-model + +# Compressing +python benchmark.py --debug compress \ + --model-load-path /path/to/full-model \ + --input-file inputfile --output-file outputfile ``` ## Running locally diff --git a/config/download_datasets.sh b/config/download_datasets.sh new file mode 100644 index 0000000..d76147d --- /dev/null +++ b/config/download_datasets.sh @@ -0,0 +1,95 @@ +#!/usr/bin/env bash +# Download all URLs (one per line) from a txt file into a destination directory. +# This script is written by Copilot + +set -uo pipefail + +usage() { + echo "Usage: $0 " + echo "Example: $0 urls.txt ~/Downloads/files" + exit 1 +} + +# ---- Args & prerequisites ---- +[[ $# -ne 2 ]] && usage + +URLS_FILE="$1" +DEST_DIR="$2" + +if [[ ! -f "$URLS_FILE" ]]; then + echo "Error: URL list file not found: $URLS_FILE" >&2 + exit 2 +fi + +mkdir -p "$DEST_DIR" || { + echo "Error: Cannot create/access destination directory: $DEST_DIR" >&2 + exit 3 +} + +# Prefer curl if available; otherwise try wget +DOWNLOADER="" +if command -v wget >/dev/null 2>&1; then + DOWNLOADER="wget" +else + echo "Error: Neither 'curl' nor 'wget' found. Please install one." 
>&2 + exit 4 +fi + +echo "Using downloader: $DOWNLOADER" +echo "Reading URLs from: $URLS_FILE" +echo "Saving to: $DEST_DIR" +echo + +# ---- Download loop ---- +# Reads lines including the last one even if it lacks a trailing newline. +while IFS= read -r url || [[ -n "$url" ]]; do + # Skip empty lines and comments + [[ -z "$url" ]] && continue + [[ "$url" =~ ^[[:space:]]*# ]] && continue + + # Optional: strip leading/trailing whitespace + url="$(printf '%s' "$url" | awk '{$1=$1;print}')" + + # Basic scheme check + if ! [[ "$url" =~ ^https?:// ]]; then + echo "Skipping (invalid URL scheme): $url" >&2 + continue + fi + + echo "→ Downloading: $url" + + if [[ "$DOWNLOADER" == "curl" ]]; then + # -f fail on HTTP errors + # -L follow redirects + # -C - resume if possible + # --retry 3 retry transient failures + # -OJ save using server-provided filename (Content-Disposition) if present + # (cd to dest so curl -O/-OJ writes there) + ( + cd "$DEST_DIR" && \ + curl -fL -C - --retry 3 --remote-header-name -OJ "$url" + ) || { + echo " ⚠️ Failed: $url" >&2 + } + else + # wget: + # --content-disposition: respect server-provided filename + # --tries=3, --timeout=10: retry/transient handling + # --directory-prefix: write to dest + # --no-clobber: skip file if it already exists + wget -q --content-disposition --tries=3 --timeout=10 \ + --directory-prefix="$DEST_DIR" --no-clobber "$url" || { + echo " ⚠️ Failed: $url" >&2 + } + fi + + # Extract .gz files + if [[ "$url" =~ \.gz$ ]]; then + filename="${url##*/}" + echo "Extracting: $filename" + gunzip "$DEST_DIR/${filename}" + fi +done < "$URLS_FILE" + +echo +echo "✅ Done. Files saved in: $DEST_DIR" diff --git a/config/generate_csv.sh b/config/generate_csv.sh new file mode 100644 index 0000000..1d4fae1 --- /dev/null +++ b/config/generate_csv.sh @@ -0,0 +1,106 @@ +#!/usr/bin/env bash +# Generate a CSV that enumerates a test grid for your Python benchmarking script. +# Columns: model,context_size,extra_args +# +# Example: +# ./generate_grid_csv.sh > grid.csv +# ./generate_grid_csv.sh -o grid.csv +# +# You can customize the axes below (MODELS, CONTEXTS, TEMPERATURES, MAX_TOKENS) +# and add common extra args (COMMON_EXTRA). All fields are safely CSV-quoted. + +set -euo pipefail + +OUT_FILE="" +SHOW_HELP=false + +usage() { + cat <<'EOF' +Usage: + generate_grid_csv.sh [-o output.csv] + +Options: + -o Write CSV to this file instead of stdout + -h Show this help + +Customize the axes by editing arrays in the script: + MODELS, CONTEXTS, TEMPERATURES, MAX_TOKENS, COMMON_EXTRA + +Examples: + ./generate_grid_csv.sh > grid.csv + ./generate_grid_csv.sh -o grid.csv + +Tip: + You can also override arrays via env vars (space-separated), e.g.: + MODELS="gpt-4o-mini llama-3.1-8b" CONTEXTS="4096 8192" ./generate_grid_csv.sh > grid.csv +EOF +} + +# --- Parse flags --- +while getopts ":o:h" opt; do + case "$opt" in + o) OUT_FILE="$OPTARG" ;; + h) SHOW_HELP=true ;; + \?) echo "Invalid option: -$OPTARG" >&2; usage; exit 2 ;; + :) echo "Option -$OPTARG requires an argument." 
>&2; exit 2 ;; + esac +done +shift $((OPTIND - 1)) + +$SHOW_HELP && { usage; exit 0; } + +# --- Axes (edit or override via env) --- +# You can override these by exporting env vars before running, e.g.: +# export MODELS="gpt-4o-mini llama-3.1-8b" +# shellcheck disable=SC2206 +DATASETS=${DATASETS:-"enwik9 human_reference"} +CONTEXTS=${CONTEXTS:-"64"} + +# Convert space-separated env vars to bash arrays +# shellcheck disable=SC2206 +DATASETS_ARR=($DATASETS) +CONTEXTS_ARR=($CONTEXTS) + +# --- CSV helpers --- +csv_escape() { + # Escape double quotes by doubling them, and wrap the whole field in quotes. + local s="$1" + s=${s//\"/\"\"} + printf '%s' "$s" +} + +emit() { + # Write to file or stdout + if [[ -n "$OUT_FILE" ]]; then + printf "%s\n" "$1" >> "$OUT_FILE" + else + printf "%s\n" "$1" + fi +} + +# Prepare output +if [[ -n "$OUT_FILE" ]]; then + : > "$OUT_FILE" # truncate/initialize +fi + +# Header +emit "id,input,model,dataset,context_size" + +# --- Generate rows (Cartesian product) --- +id=0 +model="cnn" +for file in /home/tdpeuter/data/ml-inputs/*; do + for dataset in "${DATASETS_ARR[@]}"; do + for ctx in "${CONTEXTS_ARR[@]}"; do + # CSV-quote each field + row="${id},$(csv_escape "${file}"),$(csv_escape "${model}"),$(csv_escape "${dataset}"),$ctx" + emit "$row" + id=$((id+1)) + done + done +done + +# Done +if [[ -n "$OUT_FILE" ]]; then + echo "CSV written to: $OUT_FILE" +fi diff --git a/config/local.sh b/config/local.sh new file mode 100644 index 0000000..91f79d5 --- /dev/null +++ b/config/local.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +INPUT_FILE="config/sub.csv" + +JOBID="$(date +%s | tail -c 9)" +GIT_HASH="$(git rev-parse --short HEAD)" +DATE="$(date "+%Y%m%d")" +ID="${JOBID}-${GIT_HASH}-${DATE}" +STAT_FILE="results/${ID}/results.csv" +MODELS=/home/tdpeuter/data/ml-models + +while read -r line; do + IFS=',' read -r id input model dataset context <<< "$line" + + if [[ "${id}" == "id" ]]; then + continue + fi + + python main.py compress \ + --model-load-path "${MODELS}/${dataset}/${context}/${model}-1024.pt" \ + --input-file "${input}" \ + --output-file "results/${ID}/${input}.pt" & + exit_code="${?}" + if [ "${exit_code}" -eq 0 ]; then + echo "DONE" + fi +done < "${INPUT_FILE}" diff --git a/config/configuration.nix b/config/nix/configuration.nix similarity index 100% rename from config/configuration.nix rename to config/nix/configuration.nix diff --git a/config/flake.lock b/config/nix/flake.lock similarity index 100% rename from config/flake.lock rename to config/nix/flake.lock diff --git a/config/flake.nix b/config/nix/flake.nix similarity index 100% rename from config/flake.nix rename to config/nix/flake.nix diff --git a/config/sub.csv b/config/sub.csv new file mode 100644 index 0000000..98fdf7a --- /dev/null +++ b/config/sub.csv @@ -0,0 +1,5 @@ +id,input,model,dataset,context_size +0,/home/tdpeuter/data/ml-inputs/Firefox Setup 146.0.exe,cnn,enwik9,64 +1,/home/tdpeuter/data/ml-inputs/Firefox Setup 146.0.exe,cnn,human_reference,64 +2,/home/tdpeuter/data/ml-inputs/GCF_000005845.2_ASM584v2_genomic.fna,cnn,enwik9,64 +3,/home/tdpeuter/data/ml-inputs/GCF_000005845.2_ASM584v2_genomic.fna,cnn,human_reference,64 diff --git a/config/urls.txt b/config/urls.txt new file mode 100644 index 0000000..417b877 --- /dev/null +++ b/config/urls.txt @@ -0,0 +1,2 @@ +https://download.mozilla.org/?product=firefox-latest&os=win&lang=en-US +https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/005/845/GCF_000005845.2_ASM584v2/GCF_000005845.2_ASM584v2_genomic.fna.gz \ No newline at end of file From 
eec3b2b1e624c7e2f9b8e756ee99f3b1f2a62458 Mon Sep 17 00:00:00 2001 From: Robin Meersman Date: Thu, 11 Dec 2025 22:45:28 +0100 Subject: [PATCH 06/11] feat: decompression --- src/process.py | 61 +++++++++++++++++++++++++++++++++++------- src/utils/custom_ae.py | 1 + 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/src/process.py b/src/process.py index d77f30d..a681960 100644 --- a/src/process.py +++ b/src/process.py @@ -65,18 +65,38 @@ def compress( binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) # Pack - bits = binary_code.split(".", maxsplit=1)[1] - val = int(bits, 2) if len(bits) else 0 - out_bytes = val.to_bytes((len(bits) + 7) // 8, "big") + val = int(binary_code, 2) if len(binary_code) else 0 + out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big") if output_file: print(f"Writing to {output_file}") - with open(output_file, "wb") as file: - file.write(out_bytes) + with open(output_file, "w") as file: + file.write(f"{len(byte_data)}\n") + file.write(binary_code) # todo: temporary, decoding depends on binary string else: print(out_bytes) +def bits_to_number(bits: str) -> float: + n = 0 + for i, bit in enumerate(bits, start=1): + n += int(bit) / (1 << i) + return n + + +def make_cumulative(probs): + cumulative = [] + + total = 0 + + for prob in probs: + low = total + high = total + prob + cumulative.append((low, high)) + total = high + return cumulative + + def decompress( device, model_path: str, @@ -84,10 +104,10 @@ def decompress( output_file: str | None = None ): context_length = 128 - output = bytearray() print("Reading in the data") - with open(input_file, "rb") as f: + with open(input_file, "r") as f: + length = int(f.readline()) bytes_data = f.read() if len(bytes_data) == 0: @@ -100,8 +120,29 @@ def decompress( model.eval() print("Decompressing") - ae = ArithmeticEncoding(frequency_table={0: 1}) - stage_min, stage_max = Decimal(0), Decimal(1) - context = deque([0] * context_length, maxlen=context_length) + output = bytearray() + x = bits_to_number(bytes_data) + + for _ in range(length): + probs = model(context) + cumulative = make_cumulative(probs) + + for symbol, (low, high) in enumerate(cumulative): + if low <= x < high: + break + + output.append(symbol) + context.append(chr(symbol)) + + interval_low, interval_high = cumulative[symbol] + interval_width = interval_high - interval_low + x = (x - interval_low) / interval_width + + if output_file is not None: + with open(output_file, "wb") as f: + f.write(output) + return + + print(output.decode('utf-8', errors='replace')) diff --git a/src/utils/custom_ae.py b/src/utils/custom_ae.py index 006e53b..86f3548 100644 --- a/src/utils/custom_ae.py +++ b/src/utils/custom_ae.py @@ -219,6 +219,7 @@ class CustomArithmeticEncoding: return ''.join(map(str, code)) + def decode(self, encoded_msg, msg_length, probability_table): """ Decodes a message from a floating-point number. 
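
A note on the `decompress()` loop introduced in this patch: it calls `model(context)` on the raw deque, skips the softmax, and appends `chr(symbol)` to a context that `compress()` fills with integer byte values, so the decoder's cumulative intervals will not line up with the encoder's. A minimal sketch of the per-symbol step, assuming the decoder is meant to mirror `compress()` exactly (same tensor/softmax path, same `eps` smoothing, and the same normalization that `get_probability_table` applies on the encoder side):

```python
# Sketch only: mirrors the probability handling of compress().
# model, context, device, make_cumulative and x are the names already
# defined in decompress(); eps matches the smoothing used when encoding.
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
with torch.inference_mode():
    logits = model(context_tensor)
probabilities = torch.softmax(logits[0], dim=-1).detach().cpu().numpy()

eps = 1e-10
freqs = [float(p) + eps for p in probabilities]
total = sum(freqs)  # normalize, as get_probability_table does for the encoder
cumulative = make_cumulative([f / total for f in freqs])

# ...then search cumulative for the symbol whose interval contains x and
# rescale x as in the loop above, but extend the context with the integer:
# context.append(symbol)
```
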
From 1143acc415cf4f8dd29c3c02ce2c4b8b6f1c9536 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Thu, 11 Dec 2025 22:45:46 +0100 Subject: [PATCH 07/11] chore: Replace firefox with 7zip (smaller) --- README.md | 14 +++++++++++--- config/local.sh | 13 ++++++++++++- config/sub.csv | 12 ++++++++---- config/urls.txt | 4 ++-- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 28058f6..e339dbc 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,11 @@ # neural compression +## Running locally + +``` +uv sync --all-extras +``` + Example usage: ```shell @@ -21,10 +27,12 @@ python benchmark.py --debug compress \ --input-file inputfile --output-file outputfile ``` -## Running locally +Testing compression: -``` -uv sync --all-extras +```shell +bash config/download_datasets.sh config/urls.txt /home/tdpeuter/data/ml-inputs +bash config/generate_csv.sh > config/sub.csv +bash config/local.sh ``` ## Running on the Ghent University HPC diff --git a/config/local.sh b/config/local.sh index 91f79d5..e20ddf7 100644 --- a/config/local.sh +++ b/config/local.sh @@ -9,6 +9,8 @@ ID="${JOBID}-${GIT_HASH}-${DATE}" STAT_FILE="results/${ID}/results.csv" MODELS=/home/tdpeuter/data/ml-models +mkdir -p "results/${ID}" + while read -r line; do IFS=',' read -r id input model dataset context <<< "$line" @@ -16,11 +18,20 @@ while read -r line; do continue fi + output="results/${ID}/$(basename "${input}").${id}.pt" + python main.py compress \ --model-load-path "${MODELS}/${dataset}/${context}/${model}-1024.pt" \ --input-file "${input}" \ - --output-file "results/${ID}/${input}.pt" & + --output-file "${output}" + + in_bytes="$(stat -c %s -- "${input}")" + out_bytes="$(stat -c %s -- "${output}")" + + printf "%d,%s,%s,%s,%d,%d,%d\n" "$id" "$input" "$model" "$dataset" "$context" "$in_bytes" "$out_bytes" >> "${STAT_FILE}" + exit_code="${?}" + if [ "${exit_code}" -eq 0 ]; then echo "DONE" fi diff --git a/config/sub.csv b/config/sub.csv index 98fdf7a..1794775 100644 --- a/config/sub.csv +++ b/config/sub.csv @@ -1,5 +1,9 @@ id,input,model,dataset,context_size -0,/home/tdpeuter/data/ml-inputs/Firefox Setup 146.0.exe,cnn,enwik9,64 -1,/home/tdpeuter/data/ml-inputs/Firefox Setup 146.0.exe,cnn,human_reference,64 -2,/home/tdpeuter/data/ml-inputs/GCF_000005845.2_ASM584v2_genomic.fna,cnn,enwik9,64 -3,/home/tdpeuter/data/ml-inputs/GCF_000005845.2_ASM584v2_genomic.fna,cnn,human_reference,64 +0,/home/tdpeuter/data/ml-inputs/7z2501-x64.exe,cnn,enwik9,64 +1,/home/tdpeuter/data/ml-inputs/7z2501-x64.exe,cnn,human_reference,64 +2,/home/tdpeuter/data/ml-inputs/Firefox Setup 146.0.exe,cnn,enwik9,64 +3,/home/tdpeuter/data/ml-inputs/Firefox Setup 146.0.exe,cnn,human_reference,64 +4,/home/tdpeuter/data/ml-inputs/GCF_000005845.2_ASM584v2_genomic.fna,cnn,enwik9,64 +5,/home/tdpeuter/data/ml-inputs/GCF_000005845.2_ASM584v2_genomic.fna,cnn,human_reference,64 +6,/home/tdpeuter/data/ml-inputs/GCF_000005845.2_ASM584v2_genomic.fna.gz,cnn,enwik9,64 +7,/home/tdpeuter/data/ml-inputs/GCF_000005845.2_ASM584v2_genomic.fna.gz,cnn,human_reference,64 diff --git a/config/urls.txt b/config/urls.txt index 417b877..eaf8ef9 100644 --- a/config/urls.txt +++ b/config/urls.txt @@ -1,2 +1,2 @@ -https://download.mozilla.org/?product=firefox-latest&os=win&lang=en-US -https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/005/845/GCF_000005845.2_ASM584v2/GCF_000005845.2_ASM584v2_genomic.fna.gz \ No newline at end of file +https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/005/845/GCF_000005845.2_ASM584v2/GCF_000005845.2_ASM584v2_genomic.fna.gz 
+https://www.7-zip.org/a/7z2501-x64.exe \ No newline at end of file From fc75ab51b0098369a449b8a35744d2414417739a Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Thu, 11 Dec 2025 23:23:55 +0100 Subject: [PATCH 08/11] fix: No hardcoding context len --- main.py | 3 ++- src/process.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index ae4e709..7b1aab2 100644 --- a/main.py +++ b/main.py @@ -40,7 +40,8 @@ def main(): compress(device=device, model_path=args.model_load_path, input_file=args.input_file, - output_file=args.output_file + output_file=args.output_file, + context_length=args.context ) case _: diff --git a/src/process.py b/src/process.py index a681960..dc5c479 100644 --- a/src/process.py +++ b/src/process.py @@ -1,20 +1,20 @@ +import contextlib from collections import deque from decimal import Decimal import torch -from pyae import ArithmeticEncoding from tqdm import tqdm +from src.utils import reference_ae + def compress( device, model_path: str, + context_length: int = 128, input_file: str | None = None, output_file: str | None = None ): - # NOTE Hardcoded context length - context_length = 128 - # Get input to compress print("Reading input") if input_file: From 817c16bde4032978f34ce0a320d216ac671ff5c8 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Thu, 11 Dec 2025 23:24:11 +0100 Subject: [PATCH 09/11] chore: Add reference AE --- src/utils/reference_ae.py | 601 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 601 insertions(+) create mode 100644 src/utils/reference_ae.py diff --git a/src/utils/reference_ae.py b/src/utils/reference_ae.py new file mode 100644 index 0000000..fc2a2e8 --- /dev/null +++ b/src/utils/reference_ae.py @@ -0,0 +1,601 @@ +# +# Reference arithmetic coding +# +# Copyright (c) Project Nayuki +# MIT License. See readme file. +# https://www.nayuki.io/page/reference-arithmetic-coding +# + + +# ---- Arithmetic coding core classes ---- + +# Provides the state and behaviors that arithmetic coding encoders and decoders share. +class ArithmeticCoderBase: + + # Constructs an arithmetic coder, which initializes the code range. + def __init__(self, numbits): + if numbits < 1: + raise ValueError("State size out of range") + + # -- Configuration fields -- + # Number of bits for the 'low' and 'high' state variables. Must be at least 1. + # - Larger values are generally better - they allow a larger maximum frequency total (maximum_total), + # and they reduce the approximation error inherent in adapting fractions to integers; + # both effects reduce the data encoding loss and asymptotically approach the efficiency + # of arithmetic coding using exact fractions. + # - But larger state sizes increase the computation time for integer arithmetic, + # and compression gains beyond ~30 bits essentially zero in real-world applications. + # - Python has native bigint arithmetic, so there is no upper limit to the state size. + # For Java and C++ where using native machine-sized integers makes the most sense, + # they have a recommended value of num_state_bits=32 as the most versatile setting. + self.num_state_bits = numbits + # Maximum range (high+1-low) during coding (trivial), which is 2^num_state_bits = 1000...000. + self.full_range = 1 << self.num_state_bits + # The top bit at width num_state_bits, which is 0100...000. + self.half_range = self.full_range >> 1 # Non-zero + # The second highest bit at width num_state_bits, which is 0010...000. This is zero when num_state_bits=1. 
+ self.quarter_range = self.half_range >> 1 # Can be zero + # Minimum range (high+1-low) during coding (non-trivial), which is 0010...010. + self.minimum_range = self.quarter_range + 2 # At least 2 + # Maximum allowed total from a frequency table at all times during coding. This differs from Java + # and C++ because Python's native bigint avoids constraining the size of intermediate computations. + self.maximum_total = self.minimum_range + # Bit mask of num_state_bits ones, which is 0111...111. + self.state_mask = self.full_range - 1 + + # -- State fields -- + # Low end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 0s. + self.low = 0 + # High end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 1s. + self.high = self.state_mask + + + # Updates the code range (low and high) of this arithmetic coder as a result + # of processing the given symbol with the given frequency table. + # Invariants that are true before and after encoding/decoding each symbol + # (letting full_range = 2^num_state_bits): + # - 0 <= low <= code <= high < full_range. ('code' exists only in the decoder.) + # Therefore these variables are unsigned integers of num_state_bits bits. + # - low < 1/2 * full_range <= high. + # In other words, they are in different halves of the full range. + # - (low < 1/4 * full_range) || (high >= 3/4 * full_range). + # In other words, they are not both in the middle two quarters. + # - Let range = high - low + 1, then full_range/4 < minimum_range + # <= range <= full_range. These invariants for 'range' essentially + # dictate the maximum total that the incoming frequency table can have. + def update(self, freqs, symbol): + # State check + low = self.low + high = self.high + if low >= high or (low & self.state_mask) != low or (high & self.state_mask) != high: + raise AssertionError("Low or high out of range") + range = high - low + 1 + if not (self.minimum_range <= range <= self.full_range): + raise AssertionError("Range out of range") + + # Frequency table values check + total = freqs.get_total() + symlow = freqs.get_low(symbol) + symhigh = freqs.get_high(symbol) + if symlow == symhigh: + raise ValueError("Symbol has zero frequency") + if total > self.maximum_total: + raise ValueError("Cannot code symbol because total is too large") + + # Update range + newlow = low + symlow * range // total + newhigh = low + symhigh * range // total - 1 + self.low = newlow + self.high = newhigh + + # While low and high have the same top bit value, shift them out + while ((self.low ^ self.high) & self.half_range) == 0: + self.shift() + self.low = ((self.low << 1) & self.state_mask) + self.high = ((self.high << 1) & self.state_mask) | 1 + # Now low's top bit must be 0 and high's top bit must be 1 + + # While low's top two bits are 01 and high's are 10, delete the second highest bit of both + while (self.low & ~self.high & self.quarter_range) != 0: + self.underflow() + self.low = (self.low << 1) ^ self.half_range + self.high = ((self.high ^ self.half_range) << 1) | self.half_range | 1 + + + # Called to handle the situation when the top bit of 'low' and 'high' are equal. + def shift(self): + raise NotImplementedError() + + + # Called to handle the situation when low=01(...) and high=10(...). + def underflow(self): + raise NotImplementedError() + + + +# Encodes symbols and writes to an arithmetic-coded bit stream. 
+class ArithmeticEncoder(ArithmeticCoderBase): + + # Constructs an arithmetic coding encoder based on the given bit output stream. + def __init__(self, numbits, bitout): + super(ArithmeticEncoder, self).__init__(numbits) + # The underlying bit output stream. + self.output = bitout + # Number of saved underflow bits. This value can grow without bound. + self.num_underflow = 0 + + + # Encodes the given symbol based on the given frequency table. + # This updates this arithmetic coder's state and may write out some bits. + def write(self, freqs, symbol): + if not isinstance(freqs, CheckedFrequencyTable): + freqs = CheckedFrequencyTable(freqs) + self.update(freqs, symbol) + + + # Terminates the arithmetic coding by flushing any buffered bits, so that the output can be decoded properly. + # It is important that this method must be called at the end of the each encoding process. + # Note that this method merely writes data to the underlying output stream but does not close it. + def finish(self): + self.output.write(1) + + + def shift(self): + bit = self.low >> (self.num_state_bits - 1) + self.output.write(bit) + + # Write out the saved underflow bits + for _ in range(self.num_underflow): + self.output.write(bit ^ 1) + self.num_underflow = 0 + + + def underflow(self): + self.num_underflow += 1 + + + +# Reads from an arithmetic-coded bit stream and decodes symbols. +class ArithmeticDecoder(ArithmeticCoderBase): + + # Constructs an arithmetic coding decoder based on the + # given bit input stream, and fills the code bits. + def __init__(self, numbits, bitin): + super(ArithmeticDecoder, self).__init__(numbits) + # The underlying bit input stream. + self.input = bitin + # The current raw code bits being buffered, which is always in the range [low, high]. + self.code = 0 + for _ in range(self.num_state_bits): + self.code = self.code << 1 | self.read_code_bit() + + + # Decodes the next symbol based on the given frequency table and returns it. + # Also updates this arithmetic coder's state and may read in some bits. + def read(self, freqs): + if not isinstance(freqs, CheckedFrequencyTable): + freqs = CheckedFrequencyTable(freqs) + + # Translate from coding range scale to frequency table scale + total = freqs.get_total() + if total > self.maximum_total: + raise ValueError("Cannot decode symbol because total is too large") + range = self.high - self.low + 1 + offset = self.code - self.low + value = ((offset + 1) * total - 1) // range + assert value * range // total <= offset + assert 0 <= value < total + + # A kind of binary search. Find highest symbol such that freqs.get_low(symbol) <= value. + start = 0 + end = freqs.get_symbol_limit() + while end - start > 1: + middle = (start + end) >> 1 + if freqs.get_low(middle) > value: + end = middle + else: + start = middle + assert start + 1 == end + + symbol = start + assert freqs.get_low(symbol) * range // total <= offset < freqs.get_high(symbol) * range // total + self.update(freqs, symbol) + if not (self.low <= self.code <= self.high): + raise AssertionError("Code out of range") + return symbol + + + def shift(self): + self.code = ((self.code << 1) & self.state_mask) | self.read_code_bit() + + + def underflow(self): + self.code = (self.code & self.half_range) | ((self.code << 1) & (self.state_mask >> 1)) | self.read_code_bit() + + + # Returns the next bit (0 or 1) from the input stream. The end + # of stream is treated as an infinite number of trailing zeros. 
+ def read_code_bit(self): + temp = self.input.read() + if temp == -1: + temp = 0 + return temp + + + +# ---- Frequency table classes ---- + +# A table of symbol frequencies. The table holds data for symbols numbered from 0 +# to get_symbol_limit()-1. Each symbol has a frequency, which is a non-negative integer. +# Frequency table objects are primarily used for getting cumulative symbol +# frequencies. These objects can be mutable depending on the implementation. +class FrequencyTable: + + # Returns the number of symbols in this frequency table, which is a positive number. + def get_symbol_limit(self): + raise NotImplementedError() + + # Returns the frequency of the given symbol. The returned value is at least 0. + def get(self, symbol): + raise NotImplementedError() + + # Sets the frequency of the given symbol to the given value. + # The frequency value must be at least 0. + def set(self, symbol, freq): + raise NotImplementedError() + + # Increments the frequency of the given symbol. + def increment(self, symbol): + raise NotImplementedError() + + # Returns the total of all symbol frequencies. The returned value is at + # least 0 and is always equal to get_high(get_symbol_limit() - 1). + def get_total(self): + raise NotImplementedError() + + # Returns the sum of the frequencies of all the symbols strictly + # below the given symbol value. The returned value is at least 0. + def get_low(self, symbol): + raise NotImplementedError() + + # Returns the sum of the frequencies of the given symbol + # and all the symbols below. The returned value is at least 0. + def get_high(self, symbol): + raise NotImplementedError() + + + +# An immutable frequency table where every symbol has the same frequency of 1. +# Useful as a fallback model when no statistics are available. +class FlatFrequencyTable(FrequencyTable): + + # Constructs a flat frequency table with the given number of symbols. + def __init__(self, numsyms): + if numsyms < 1: + raise ValueError("Number of symbols must be positive") + self.numsymbols = numsyms # Total number of symbols, which is at least 1 + + # Returns the number of symbols in this table, which is at least 1. + def get_symbol_limit(self): + return self.numsymbols + + # Returns the frequency of the given symbol, which is always 1. + def get(self, symbol): + self._check_symbol(symbol) + return 1 + + # Returns the total of all symbol frequencies, which is + # always equal to the number of symbols in this table. + def get_total(self): + return self.numsymbols + + # Returns the sum of the frequencies of all the symbols strictly below + # the given symbol value. The returned value is equal to 'symbol'. + def get_low(self, symbol): + self._check_symbol(symbol) + return symbol + + + # Returns the sum of the frequencies of the given symbol and all + # the symbols below. The returned value is equal to 'symbol' + 1. + def get_high(self, symbol): + self._check_symbol(symbol) + return symbol + 1 + + + # Returns silently if 0 <= symbol < numsymbols, otherwise raises an exception. + def _check_symbol(self, symbol): + if not (0 <= symbol < self.numsymbols): + raise ValueError("Symbol out of range") + + # Returns a string representation of this frequency table. The format is subject to change. + def __str__(self): + return "FlatFrequencyTable={}".format(self.numsymbols) + + # Unsupported operation, because this frequency table is immutable. + def set(self, symbol, freq): + raise NotImplementedError() + + # Unsupported operation, because this frequency table is immutable. 
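+    # (For an adaptive model, use the mutable SimpleFrequencyTable defined below instead.)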
+ def increment(self, symbol): + raise NotImplementedError() + + + +# A mutable table of symbol frequencies. The number of symbols cannot be changed +# after construction. The current algorithm for calculating cumulative frequencies +# takes linear time, but there exist faster algorithms such as Fenwick trees. +class SimpleFrequencyTable(FrequencyTable): + + # Constructs a simple frequency table in one of two ways: + # - SimpleFrequencyTable(sequence): + # Builds a frequency table from the given sequence of symbol frequencies. + # There must be at least 1 symbol, and no symbol has a negative frequency. + # - SimpleFrequencyTable(freqtable): + # Builds a frequency table by copying the given frequency table. + def __init__(self, freqs): + if isinstance(freqs, FrequencyTable): + numsym = freqs.get_symbol_limit() + self.frequencies = [freqs.get(i) for i in range(numsym)] + else: # Assume it is a sequence type + self.frequencies = list(freqs) # Make copy + + # 'frequencies' is a list of the frequency for each symbol. + # Its length is at least 1, and each element is non-negative. + if len(self.frequencies) < 1: + raise ValueError("At least 1 symbol needed") + for freq in self.frequencies: + if freq < 0: + raise ValueError("Negative frequency") + + # Always equal to the sum of 'frequencies' + self.total = sum(self.frequencies) + + # cumulative[i] is the sum of 'frequencies' from 0 (inclusive) to i (exclusive). + # Initialized lazily. When it is not None, the data is valid. + self.cumulative = None + + + # Returns the number of symbols in this frequency table, which is at least 1. + def get_symbol_limit(self): + return len(self.frequencies) + + + # Returns the frequency of the given symbol. The returned value is at least 0. + def get(self, symbol): + self._check_symbol(symbol) + return self.frequencies[symbol] + + + # Sets the frequency of the given symbol to the given value. The frequency value + # must be at least 0. If an exception is raised, then the state is left unchanged. + def set(self, symbol, freq): + self._check_symbol(symbol) + if freq < 0: + raise ValueError("Negative frequency") + temp = self.total - self.frequencies[symbol] + assert temp >= 0 + self.total = temp + freq + self.frequencies[symbol] = freq + self.cumulative = None + + + # Increments the frequency of the given symbol. + def increment(self, symbol): + self._check_symbol(symbol) + self.total += 1 + self.frequencies[symbol] += 1 + self.cumulative = None + + + # Returns the total of all symbol frequencies. The returned value is at + # least 0 and is always equal to get_high(get_symbol_limit() - 1). + def get_total(self): + return self.total + + + # Returns the sum of the frequencies of all the symbols strictly + # below the given symbol value. The returned value is at least 0. + def get_low(self, symbol): + self._check_symbol(symbol) + if self.cumulative is None: + self._init_cumulative() + return self.cumulative[symbol] + + + # Returns the sum of the frequencies of the given symbol + # and all the symbols below. The returned value is at least 0. + def get_high(self, symbol): + self._check_symbol(symbol) + if self.cumulative is None: + self._init_cumulative() + return self.cumulative[symbol + 1] + + + # Recomputes the array of cumulative symbol frequencies. + def _init_cumulative(self): + cumul = [0] + sum = 0 + for freq in self.frequencies: + sum += freq + cumul.append(sum) + assert sum == self.total + self.cumulative = cumul + + + # Returns silently if 0 <= symbol < len(frequencies), otherwise raises an exception. 
+ def _check_symbol(self, symbol): + if not (0 <= symbol < len(self.frequencies)): + raise ValueError("Symbol out of range") + + + # Returns a string representation of this frequency table, + # useful for debugging only, and the format is subject to change. + def __str__(self): + result = "" + for (i, freq) in enumerate(self.frequencies): + result += "{}\t{}\n".format(i, freq) + return result + + + +# A wrapper that checks the preconditions (arguments) and postconditions (return value) of all +# the frequency table methods. Useful for finding faults in a frequency table implementation. +class CheckedFrequencyTable(FrequencyTable): + + def __init__(self, freqtab): + # The underlying frequency table that holds the data + self.freqtable = freqtab + + + def get_symbol_limit(self): + result = self.freqtable.get_symbol_limit() + if result <= 0: + raise AssertionError("Non-positive symbol limit") + return result + + + def get(self, symbol): + result = self.freqtable.get(symbol) + if not self._is_symbol_in_range(symbol): + raise AssertionError("ValueError expected") + if result < 0: + raise AssertionError("Negative symbol frequency") + return result + + + def get_total(self): + result = self.freqtable.get_total() + if result < 0: + raise AssertionError("Negative total frequency") + return result + + + def get_low(self, symbol): + if self._is_symbol_in_range(symbol): + low = self.freqtable.get_low (symbol) + high = self.freqtable.get_high(symbol) + if not (0 <= low <= high <= self.freqtable.get_total()): + raise AssertionError("Symbol low cumulative frequency out of range") + return low + else: + self.freqtable.get_low(symbol) + raise AssertionError("ValueError expected") + + + def get_high(self, symbol): + if self._is_symbol_in_range(symbol): + low = self.freqtable.get_low (symbol) + high = self.freqtable.get_high(symbol) + if not (0 <= low <= high <= self.freqtable.get_total()): + raise AssertionError("Symbol high cumulative frequency out of range") + return high + else: + self.freqtable.get_high(symbol) + raise AssertionError("ValueError expected") + + + def __str__(self): + return "CheckedFrequencyTable (" + str(self.freqtable) + ")" + + + def set(self, symbol, freq): + self.freqtable.set(symbol, freq) + if not self._is_symbol_in_range(symbol) or freq < 0: + raise AssertionError("ValueError expected") + + + def increment(self, symbol): + self.freqtable.increment(symbol) + if not self._is_symbol_in_range(symbol): + raise AssertionError("ValueError expected") + + + def _is_symbol_in_range(self, symbol): + return 0 <= symbol < self.get_symbol_limit() + + + +# ---- Bit-oriented I/O streams ---- + +# A stream of bits that can be read. Because they come from an underlying byte stream, +# the total number of bits is always a multiple of 8. The bits are read in big endian. +class BitInputStream: + + # Constructs a bit input stream based on the given byte input stream. + def __init__(self, inp): + # The underlying byte stream to read from + self.input = inp + # Either in the range [0x00, 0xFF] if bits are available, or -1 if end of stream is reached + self.currentbyte = 0 + # Number of remaining bits in the current byte, always between 0 and 7 (inclusive) + self.numbitsremaining = 0 + + + # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or -1 if + # the end of stream is reached. The end of stream always occurs on a byte boundary. 
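+    # Bits are returned most-significant first: for a stream containing the single
+    # byte 0xB1 (binary 10110001), eight reads return 1,0,1,1,0,0,0,1 and the ninth
+    # returns -1.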
+ def read(self): + if self.currentbyte == -1: + return -1 + if self.numbitsremaining == 0: + temp = self.input.read(1) + if len(temp) == 0: + self.currentbyte = -1 + return -1 + self.currentbyte = temp[0] + self.numbitsremaining = 8 + assert self.numbitsremaining > 0 + self.numbitsremaining -= 1 + return (self.currentbyte >> self.numbitsremaining) & 1 + + + # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or raises an EOFError + # if the end of stream is reached. The end of stream always occurs on a byte boundary. + def read_no_eof(self): + result = self.read() + if result != -1: + return result + else: + raise EOFError() + + + # Closes this stream and the underlying input stream. + def close(self): + self.input.close() + self.currentbyte = -1 + self.numbitsremaining = 0 + + + +# A stream where bits can be written to. Because they are written to an underlying +# byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits. +# The bits are written in big endian. +class BitOutputStream: + + # Constructs a bit output stream based on the given byte output stream. + def __init__(self, out): + self.output = out # The underlying byte stream to write to + self.currentbyte = 0 # The accumulated bits for the current byte, always in the range [0x00, 0xFF] + self.numbitsfilled = 0 # Number of accumulated bits in the current byte, always between 0 and 7 (inclusive) + + + # Writes a bit to the stream. The given bit must be 0 or 1. + def write(self, b): + if b not in (0, 1): + raise ValueError("Argument must be 0 or 1") + self.currentbyte = (self.currentbyte << 1) | b + self.numbitsfilled += 1 + if self.numbitsfilled == 8: + towrite = bytes((self.currentbyte,)) + self.output.write(towrite) + self.currentbyte = 0 + self.numbitsfilled = 0 + + + # Closes this stream and the underlying output stream. If called when this + # bit stream is not at a byte boundary, then the minimum number of "0" bits + # (between 0 and 7 of them) are written as padding to reach the next byte boundary. 
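+    # For example, if only the bits 1,0,1 have been written since the last complete
+    # byte, close() pads with five 0 bits and emits the single byte 0xA0.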
+ def close(self): + while self.numbitsfilled != 0: + self.write(0) + self.output.close() From 2c0e0c227877038a9a89363e71cb715199ce0041 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Thu, 11 Dec 2025 23:24:34 +0100 Subject: [PATCH 10/11] WIP: Attempt at switching --- src/process.py | 66 ++++++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/src/process.py b/src/process.py index dc5c479..e59defd 100644 --- a/src/process.py +++ b/src/process.py @@ -36,45 +36,49 @@ def compress( # Init AE print("Initializing AE") - AE = ArithmeticEncoding(frequency_table={0: 1}) # These are dummies because they are not used - stage_min, stage_max = Decimal(0), Decimal(1) - stage = None - # Compress - context = deque([0] * context_length, maxlen=context_length) - for byte in tqdm(tensor.tolist(), desc="Compressing"): - context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device) + with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout: + enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout) - with torch.inference_mode(): - logits = model(context_tensor) - probabilities = torch.softmax(logits[0], dim=-1) - probabilities = probabilities.detach().cpu().numpy() + context = deque([0] * context_length, maxlen=context_length) - eps = 1e-10 - frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} - probability_table = AE.get_probability_table(frequency_table) + stage_min, stage_max = Decimal(0), Decimal(1) + stage = None - stage = AE.process_stage(probability_table, stage_min, stage_max) - stage_min, stage_max = stage[byte] + # Compress + for byte in tqdm(tensor.tolist(), desc="Compressing"): + context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device) - context.append(byte) + with torch.inference_mode(): + logits = model(context_tensor) + probabilities = torch.softmax(logits[0], dim=-1) + probabilities = probabilities.detach().cpu().numpy() - print("Getting encoded value") - interval_min, interval_max, _ = AE.get_encoded_value(stage) - print("Encoding in binary") - binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) + eps = 1e-10 + frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} + probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities)) + probability_table = AE.get_probability_table(frequency_table) - # Pack - val = int(binary_code, 2) if len(binary_code) else 0 - out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big") + enc.write(frequency_table, byte) - if output_file: - print(f"Writing to {output_file}") - with open(output_file, "w") as file: - file.write(f"{len(byte_data)}\n") - file.write(binary_code) # todo: temporary, decoding depends on binary string - else: - print(out_bytes) + context.append(byte) + + # print("Getting encoded value") + # interval_min, interval_max, _ = AE.get_encoded_value(stage) + # print("Encoding in binary") + # binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) + + # Pack + # val = int(binary_code, 2) if len(binary_code) else 0 + # out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big") + + # if output_file: + # print(f"Writing to {output_file}") + # with open(output_file, "w") as file: + # file.write(f"{len(byte_data)}\n") + # file.write(binary_code) # todo: temporary, decoding depends on binary string + # else: + # print(out_bytes) def bits_to_number(bits: str) -> float: From 
961d642dd80da563049caf62ec848a2c41cb120e Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Thu, 11 Dec 2025 23:58:54 +0100 Subject: [PATCH 11/11] fixup! WIP: Attempt at switching --- src/process.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/process.py b/src/process.py index e59defd..31de886 100644 --- a/src/process.py +++ b/src/process.py @@ -2,12 +2,26 @@ import contextlib from collections import deque from decimal import Decimal +import numpy as np import torch from tqdm import tqdm from src.utils import reference_ae +def probs_to_freqs(probs, total_freq=8192): + freqs = (probs * total_freq).round().long() + + # Ensure no zero-frequency symbol if needed + freqs[freqs == 0] = 1 + + # Re-normalize so the sum matches total_freq + diff = total_freq - freqs.sum() + freqs[0] += diff # fix the sum by adjusting the first bin + + return freqs + + def compress( device, model_path: str, @@ -51,15 +65,22 @@ def compress( with torch.inference_mode(): logits = model(context_tensor) + #normalize + mean = logits.mean(dim=-1, keepdim=True) + std = logits.std(dim=-1, keepdim=True) + logits = (logits - mean) / (std + 1e-6) + print(f"logits: {logits}") probabilities = torch.softmax(logits[0], dim=-1) - probabilities = probabilities.detach().cpu().numpy() + print(f"probabilities: {probabilities}") + probabilities = probabilities.detach() - eps = 1e-10 - frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} - probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities)) - probability_table = AE.get_probability_table(frequency_table) + eps = 1e-8 + # np.add(probabilities, eps) + # frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} + probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities)) + # probability_table = AE.get_probability_table(frequency_table) - enc.write(frequency_table, byte) + enc.write(probability_table, byte) context.append(byte)
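A possible shape for the matching decode loop, sketched with the same reference coder.
Assumptions: the model, context_length and probs_to_freqs above are reused unchanged,
the number of original bytes (num_bytes) is stored out of band, the logit preprocessing
matches compress() exactly, and both sides construct the coder with the same number of
state bits (32 below is a placeholder). The names decompress_sketch and num_bytes are
hypothetical; this is an illustration, not part of the patches.

    import contextlib
    from collections import deque

    import torch

    from src.utils import reference_ae

    def decompress_sketch(device, model, input_file, num_bytes, context_length=128):
        out = bytearray()
        with contextlib.closing(reference_ae.BitInputStream(open(input_file, "rb"))) as bitin:
            dec = reference_ae.ArithmeticDecoder(32, bitin)  # must match the encoder's numbits
            context = deque([0] * context_length, maxlen=context_length)
            for _ in range(num_bytes):
                context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
                with torch.inference_mode():
                    logits = model(context_tensor)
                    # Apply exactly the same preprocessing as compress(); otherwise the
                    # frequency tables diverge and decoding fails.
                    probabilities = torch.softmax(logits[0], dim=-1).detach()
                freq_table = reference_ae.SimpleFrequencyTable(
                    [int(f) for f in probs_to_freqs(probabilities)])
                byte = dec.read(freq_table)
                out.append(byte)
                context.append(byte)
        return bytes(out)

Note that SimpleFrequencyTable rejects negative frequencies and the coder rejects
zero-frequency symbols, so the rounding adjustment in probs_to_freqs (freqs[0] += diff)
has to leave every entry at least 1 on both the compressing and decompressing side.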