feat: measuring code + graph generator code

RobinMeersman 2025-12-15 22:53:32 +01:00
parent dd0b3d3945
commit f3b07c1df3
6 changed files with 325 additions and 140 deletions


@@ -58,8 +58,6 @@ class AutoEncoder(Model):
         """
         x: torch.Tensor of floats
         """
-        if len(x.shape) == 2:
-            x = x.unsqueeze(1)
         return self.decoder(x)

     def forward(self, x: torch.LongTensor) -> torch.Tensor:


@@ -7,10 +7,13 @@ import numpy as np
 import torch
 import torch.nn as nn
 from tqdm import tqdm
+import struct

 from src.models import AutoEncoder
 from src.utils import reference_ae

+NUMBITS = 64
+
 def probs_to_freqs(probs, total_freq=8192):
     freqs = (probs * total_freq).round().long()
@@ -20,7 +23,7 @@ def probs_to_freqs(probs, total_freq=8192):

     # Re-normalize so the sum matches total_freq
     diff = total_freq - freqs.sum()
-    freqs[0] += diff  # fix the sum by adjusting the first bin
+    freqs[freqs.argmax()] += diff  # fix the sum by adjusting the largest bin
     return freqs
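Note: moving the rounding correction from bin 0 to the argmax bin matters when diff is negative, since the first bin may hold a tiny count that the subtraction would push to zero or below, while the largest bin can always absorb it. A standalone sketch of the behavior (the clamp to a minimum frequency of 1 is an assumption about the lines this hunk does not show):

import torch

def probs_to_freqs_sketch(probs: torch.Tensor, total_freq: int = 8192) -> torch.Tensor:
    freqs = (probs * total_freq).round().long()
    freqs = freqs.clamp(min=1)       # assumed: every symbol must stay encodable
    diff = total_freq - freqs.sum()  # rounding error to redistribute
    freqs[freqs.argmax()] += diff    # the largest bin can absorb a negative diff
    return freqs

probs = torch.softmax(torch.randn(256), dim=-1)
assert probs_to_freqs_sketch(probs).sum().item() == 8192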
@@ -32,32 +35,39 @@ def ae_compress(
     model: nn.Module,
     byte_data: bytes,
     tensor: torch.Tensor
 ):
     # Init AE
     print("Initializing AE")
-    with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
-        enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
-        context = deque([0] * context_length, maxlen=context_length)
+    with open(output_file, "wb") as raw_out:
+        # Write original length header (8 bytes)
+        raw_out.write(struct.pack(">Q", len(byte_data)))

-        # Compress
-        for byte in tqdm(tensor.tolist(), desc="Compressing"):
-            context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
-            with torch.inference_mode():
-                logits = model(context_tensor)
-                probabilities = torch.softmax(logits[0], dim=-1)
-            print(f"probabilities: {probabilities}")
-            probabilities = probabilities.detach()
-            probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
-            # write byte to output file
-            enc.write(probability_table, byte)
-            context.append(byte)
+        with contextlib.closing(reference_ae.BitOutputStream(raw_out)) as bitout:
+            enc = reference_ae.ArithmeticEncoder(NUMBITS, bitout)
+            context = deque([0] * context_length, maxlen=context_length)

+            for byte in tqdm(tensor.tolist(), desc="Compressing"):
+                context_tensor = torch.tensor(
+                    [list(context)],
+                    dtype=torch.long,
+                    device=device
+                )
+                with torch.inference_mode():
+                    logits = model(context_tensor)
+                    probabilities = torch.softmax(logits[0], dim=-1)
+                freqs = probs_to_freqs(probabilities).tolist()
+                probability_table = reference_ae.SimpleFrequencyTable(freqs)
+                enc.write(probability_table, byte)
+                context.append(byte)

-def chunk_data(x: bytes, context_length = 128) -> torch.Tensor:
+            enc.finish()


+def chunk_data(x: bytes, context_length=128) -> torch.Tensor:
     tensor_data = torch.tensor(list(x), dtype=torch.long)
     shape = tensor_data.size(0)
     row_count = math.ceil(shape / context_length)
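Note: the compressor now frames the stream with an 8-byte big-endian length so the decoder knows exactly how many symbols to read; ae_decompress below mirrors this. A minimal sketch of the framing, independent of the model (buffer contents are illustrative):

import io
import struct

buf = io.BytesIO()
buf.write(struct.pack(">Q", 1000))   # 8-byte big-endian original length
buf.write(b"\x12\x34")               # stand-in for the arithmetic-coded payload

buf.seek(0)
header = buf.read(8)
assert len(header) == 8
(original_length,) = struct.unpack(">Q", header)
assert original_length == 1000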
@@ -65,13 +75,14 @@ def chunk_data(x: bytes, context_length = 128) -> torch.Tensor:
         tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0)
     return tensor_data.view(row_count, context_length).float() / 255.0

+
 def auto_encoder_compress(
     data: bytes,
     model: AutoEncoder,
-    output_file: str,
+    output_file: str | None = None,
     context_length: int = 128,
     device: str = "cuda"
-):
+) -> torch.Tensor:
     # convert data to chunks of context length tensors
     # send the data to device
     tensor = chunk_data(data, context_length).to(device)
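Note: chunk_data zero-pads the input up to a multiple of context_length, reshapes it to (rows, context_length), and scales bytes into [0, 1]. A usage sketch of the expected shapes (the pad_count line sits in the hidden part of the hunk, so it is reconstructed here):

import math
import torch
import torch.nn as nn

def chunk_data(x: bytes, context_length: int = 128) -> torch.Tensor:
    tensor_data = torch.tensor(list(x), dtype=torch.long)
    row_count = math.ceil(tensor_data.size(0) / context_length)
    pad_count = row_count * context_length - tensor_data.size(0)  # reconstructed
    tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0)
    return tensor_data.view(row_count, context_length).float() / 255.0

out = chunk_data(bytes(range(200)))
assert out.shape == (2, 128)   # 200 bytes -> 2 rows, last 56 entries zero-padded
assert out.max().item() <= 1.0

One consequence worth flagging: the padding is not recorded anywhere on the autoencoder path, so decompression can return up to context_length - 1 trailing zero bytes that were never in the input.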
@@ -83,10 +94,11 @@ def auto_encoder_compress(
     print(f"output shape of compress: {4 * output.shape[0] * output.shape[1]} bytes")

     # write output to file
-    print(f"saving to file {output_file}...")
-    torch.save(output.detach(), output_file)
+    if output_file is not None:
+        print(f"saving to file {output_file}...")
+        torch.save(output.detach(), output_file)

     return output


 def compress(
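Note: the printed size above assumes a float32 latent, hence the factor 4. Worked check of the arithmetic:

import torch

output = torch.zeros(64, 32)   # illustrative latent of shape (rows, compressed_len)
size_bytes = 4 * output.shape[0] * output.shape[1]
assert size_bytes == output.numel() * output.element_size() == 8192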
@@ -99,7 +111,7 @@ def compress(
 ):
     # Get input to compress
     print("Reading input")
-    if input_file:
+    if input_file is not None:
         with open(input_file, "rb") as file:
             byte_data = file.read()
     else:
@@ -111,14 +123,14 @@ def compress(
     tensor = torch.tensor(list(byte_data), dtype=torch.long)

     # Get model
-    print("Loading model")
+    print(f"Loading model: {model_name}")
     model = torch.load(model_path, weights_only=False)
     model.to(device)
     model.eval()

     match model_name:
         case "cnn":
-            ae_compress(
+            return ae_compress(
                 output_file,
                 context_length,
                 device,
@@ -127,7 +139,7 @@ def compress(
                 tensor
             )
         case "autoencoder":
-            auto_encoder_compress(
+            return auto_encoder_compress(
                 byte_data,
                 model,
                 output_file,
@@ -138,16 +150,75 @@ def compress(
             raise ValueError(f"Unknown model type: {model_name}")


 def ae_decompress(
     model: nn.Module,
+    input_file: str,
     context_length=128,
     device="cuda",
+    output_file: str | None = None
 ):
-    pass
+    print("Initializing AE decoder")
+    with open(input_file, "rb") as raw_in:
+        # Read original length header
+        original_length_bytes = raw_in.read(8)
+        if len(original_length_bytes) != 8:
+            raise ValueError("Invalid compressed file (missing length header)")
+        original_length = struct.unpack(">Q", original_length_bytes)[0]
+        print(f"Original length: {original_length} bytes")

+        with contextlib.closing(reference_ae.BitInputStream(raw_in)) as bitin:
+            dec = reference_ae.ArithmeticDecoder(NUMBITS, bitin)
+            context = deque([0] * context_length, maxlen=context_length)
+            output_data = []

+            # Decode exactly original_length bytes
+            for _ in range(original_length):
+                context_tensor = torch.tensor(
+                    [list(context)],
+                    dtype=torch.long,
+                    device=device
+                )
+                with torch.inference_mode():
+                    logits = model(context_tensor)
+                    probabilities = torch.softmax(logits[0], dim=-1)
+                freqs = probs_to_freqs(probabilities).tolist()
+                probability_table = reference_ae.SimpleFrequencyTable(freqs)
+                byte = dec.read(probability_table)
+                output_data.append(byte)
+                context.append(byte)

+    byte_data = torch.tensor(output_data, dtype=torch.long).byte()
+    if output_file is not None:
+        with open(output_file, "wb") as file:
+            file.write(byte_data.cpu().numpy().tobytes())
+    return byte_data


 def auto_encoder_decompress(
     data: torch.Tensor,
     model: AutoEncoder,
+    output_file: str | None = None,
     context_length=128,
     device="cuda"
-):
-    pass
+) -> torch.Tensor:
+    decompressed = model.decode(data).squeeze(1)

+    # convert result back to bytes
+    byte_data = (decompressed * 255.0).round().byte().detach()
+    if output_file is not None:
+        with open(output_file, "wb") as file:
+            file.write(byte_data.cpu().numpy().tobytes())
+    return byte_data


 def decompress(
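Note: arithmetic decoding only works if the decoder rebuilds bit-for-bit the same frequency table the encoder used at every step, which is why both paths run the model under inference_mode on an identically updated context and why the model.eval() call in compress matters (dropout noise would desynchronize the streams). A self-contained determinism check with a stand-in model (TinyModel is illustrative, not the repo's CNN; probs_to_freqs is assumed in scope):

from collections import deque

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(256, 16)
        self.head = nn.Linear(16, 256)

    def forward(self, x):
        return self.head(self.emb(x).mean(dim=1))

def freqs_for_context(model, context, device="cpu"):
    t = torch.tensor([list(context)], dtype=torch.long, device=device)
    with torch.inference_mode():
        probabilities = torch.softmax(model(t)[0], dim=-1)
    return probs_to_freqs(probabilities).tolist()

model = TinyModel().eval()
ctx = deque([0] * 128, maxlen=128)
# Encoder and decoder must agree on every table:
assert freqs_for_context(model, ctx) == freqs_for_context(model, ctx)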
@@ -156,14 +227,16 @@ def decompress(
     model_name: str,
     input_file: str,
     output_file: str | None = None,
     context_length: int = 128
 ):
     print("Reading in the data")
-    with open(input_file, "r") as f:
-        length = int(f.readline())
-        bytes_data = f.read()
+    if model_name != "autoencoder":
+        with open(input_file, "rb") as f:
+            data = f.read()
+    else:
+        data = torch.load(input_file, map_location=device)

-    if len(bytes_data) == 0:
+    if len(data) == 0:
         print("Input file is empty, nothing has to be done...")
         return
@@ -174,8 +247,19 @@ def decompress(
     match model_name:
         case "cnn":
-            ae_decompress()
+            return ae_decompress(
+                model=model,
+                input_file=input_file,
+                context_length=context_length,
+                output_file=output_file
+            )
         case "autoencoder":
-            auto_encoder_decompress()
+            return auto_encoder_decompress(
+                data,
+                model,
+                output_file,
+                context_length,
+                device
+            )
         case _:
             raise ValueError(f"Unknown model type: {model_name}")
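Note: with both arms now returning their result, a roundtrip smoke test becomes possible. The keyword names below (model_path, device) and the checkpoint path are guessed from the visible hunks, so treat this as a sketch rather than a fixture:

with open("input.bin", "rb") as f:
    original = f.read()

compress(
    model_path="checkpoints/cnn.pt",   # placeholder checkpoint
    model_name="cnn",
    input_file="input.bin",
    output_file="input.cnn",
)
restored = decompress(
    model_path="checkpoints/cnn.pt",
    model_name="cnn",
    input_file="input.cnn",
    output_file="roundtrip.bin",
)

# ae_decompress returns a uint8 tensor of exactly the original length
assert bytes(restored.tolist()) == original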