feat (WIP): Compress
This commit is contained in:
parent
d0457b6571
commit
5c26a52e16
4 changed files with 70 additions and 8 deletions
|
|
@ -9,6 +9,7 @@ dependencies = [
|
||||||
"huggingface_hub==0.27.0",
|
"huggingface_hub==0.27.0",
|
||||||
"fsspec==2024.9.0",
|
"fsspec==2024.9.0",
|
||||||
"lorem>=0.1.1",
|
"lorem>=0.1.1",
|
||||||
|
"arithmeticencodingpython",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
@ -21,3 +22,6 @@ dev = [
|
||||||
"torchdata==0.7.1",
|
"torchdata==0.7.1",
|
||||||
"torchvision==0.24.0",
|
"torchvision==0.24.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
arithmeticencodingpython = { git = "https://github.com/ahmedfgad/ArithmeticEncodingPython.git", rev = "60aad0528c57289218b241d75993574f31b90456" }
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,22 @@
|
||||||
|
from collections import deque
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pyae import ArithmeticEncoding
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def compress(
|
def compress(
|
||||||
device,
|
device,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
output_file: str,
|
input_file: str | None = None,
|
||||||
input_file: str | None = None
|
output_file: str | None = None
|
||||||
):
|
):
|
||||||
|
# NOTE Hardcoded context length
|
||||||
|
context_length = 128
|
||||||
|
|
||||||
# Get input to compress
|
# Get input to compress
|
||||||
|
print("Reading input")
|
||||||
if input_file:
|
if input_file:
|
||||||
with open(input_file, "rb") as file:
|
with open(input_file, "rb") as file:
|
||||||
byte_data = file.read()
|
byte_data = file.read()
|
||||||
|
|
@ -16,14 +25,56 @@ def compress(
|
||||||
text = input()
|
text = input()
|
||||||
byte_data = text.encode('utf-8', errors='replace')
|
byte_data = text.encode('utf-8', errors='replace')
|
||||||
|
|
||||||
|
print("Converting to tensor")
|
||||||
tensor = torch.tensor(list(byte_data), dtype=torch.long)
|
tensor = torch.tensor(list(byte_data), dtype=torch.long)
|
||||||
print(tensor)
|
|
||||||
|
|
||||||
# Get model
|
# Get model
|
||||||
|
print("Loading model")
|
||||||
model = torch.load(model_path, weights_only=False)
|
model = torch.load(model_path, weights_only=False)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
# TODO Feed to model for compression, store result
|
# Init AE
|
||||||
return
|
print("Initializing AE")
|
||||||
|
AE = ArithmeticEncoding(frequency_table={0: 1}) # These are dummies because they are not used
|
||||||
|
stage_min, stage_max = Decimal(0), Decimal(1)
|
||||||
|
stage = None
|
||||||
|
|
||||||
|
# Compress
|
||||||
|
context = deque([0] * context_length, maxlen=context_length)
|
||||||
|
for byte in tqdm(tensor.tolist(), desc="Compressing"):
|
||||||
|
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
logits = model(context_tensor)
|
||||||
|
probabilities = torch.softmax(logits[0], dim=-1)
|
||||||
|
probabilities = probabilities.detach().cpu().numpy()
|
||||||
|
|
||||||
|
eps = 1e-10
|
||||||
|
frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
|
||||||
|
probability_table = AE.get_probability_table(frequency_table)
|
||||||
|
|
||||||
|
stage = AE.process_stage(probability_table, stage_min, stage_max)
|
||||||
|
stage_min, stage_max = stage[byte]
|
||||||
|
|
||||||
|
context.append(byte)
|
||||||
|
|
||||||
|
print("Getting encoded value")
|
||||||
|
interval_min, interval_max, _ = AE.get_encoded_value(stage)
|
||||||
|
print("Encoding in binary")
|
||||||
|
binary_code, _ = AE.encode_binary(interval_min, interval_max)
|
||||||
|
|
||||||
|
# Pack
|
||||||
|
bits = binary_code.split(".", maxsplit=1)[1]
|
||||||
|
val = int(bits, 2) if len(bits) else 0
|
||||||
|
out_bytes = val.to_bytes((len(bits) + 7) // 8, "big")
|
||||||
|
|
||||||
|
if output_file:
|
||||||
|
print(f"Writing to {output_file}")
|
||||||
|
with open(output_file, "wb") as file:
|
||||||
|
file.write(out_bytes)
|
||||||
|
else:
|
||||||
|
print(out_bytes)
|
||||||
|
|
||||||
|
|
||||||
def decompress():
|
def decompress():
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ def train(
|
||||||
model_path: str | None = None,
|
model_path: str | None = None,
|
||||||
model_out: str | None = None
|
model_out: str | None = None
|
||||||
):
|
):
|
||||||
batch_size = 2
|
batch_size = 64
|
||||||
|
|
||||||
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
|
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
|
||||||
|
|
||||||
|
|
|
||||||
7
uv.lock
generated
7
uv.lock
generated
|
|
@ -163,6 +163,11 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" },
|
{ url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "arithmeticencodingpython"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = { git = "https://github.com/ahmedfgad/ArithmeticEncodingPython.git?rev=60aad0528c57289218b241d75993574f31b90456#60aad0528c57289218b241d75993574f31b90456" }
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "attrs"
|
name = "attrs"
|
||||||
version = "25.4.0"
|
version = "25.4.0"
|
||||||
|
|
@ -1621,6 +1626,7 @@ name = "project-ml"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
source = { virtual = "." }
|
source = { virtual = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
{ name = "arithmeticencodingpython" },
|
||||||
{ name = "datasets" },
|
{ name = "datasets" },
|
||||||
{ name = "fsspec" },
|
{ name = "fsspec" },
|
||||||
{ name = "huggingface-hub" },
|
{ name = "huggingface-hub" },
|
||||||
|
|
@ -1640,6 +1646,7 @@ dev = [
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
|
{ name = "arithmeticencodingpython", git = "https://github.com/ahmedfgad/ArithmeticEncodingPython.git?rev=60aad0528c57289218b241d75993574f31b90456" },
|
||||||
{ name = "datasets", specifier = ">=3.2.0" },
|
{ name = "datasets", specifier = ">=3.2.0" },
|
||||||
{ name = "fsspec", specifier = "==2024.9.0" },
|
{ name = "fsspec", specifier = "==2024.9.0" },
|
||||||
{ name = "huggingface-hub", specifier = "==0.27.0" },
|
{ name = "huggingface-hub", specifier = "==0.27.0" },
|
||||||
|
|
|
||||||
Reference in a new issue