feat: measuring code + graph generator code
This commit is contained in:
parent
dd0b3d3945
commit
f3b07c1df3
6 changed files with 325 additions and 140 deletions
|
|
@ -58,8 +58,6 @@ class AutoEncoder(Model):
|
|||
"""
|
||||
x: torch.Tensor of floats
|
||||
"""
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(1)
|
||||
return self.decoder(x)
|
||||
|
||||
def forward(self, x: torch.LongTensor) -> torch.Tensor:
|
||||
|
|
|
|||
164
src/process.py
164
src/process.py
|
|
@ -7,10 +7,13 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
import struct
|
||||
|
||||
from src.models import AutoEncoder
|
||||
from src.utils import reference_ae
|
||||
|
||||
NUMBITS = 64
|
||||
|
||||
|
||||
def probs_to_freqs(probs, total_freq=8192):
|
||||
freqs = (probs * total_freq).round().long()
|
||||
|
|
@ -20,7 +23,7 @@ def probs_to_freqs(probs, total_freq=8192):
|
|||
|
||||
# Re-normalize so the sum matches total_freq
|
||||
diff = total_freq - freqs.sum()
|
||||
freqs[0] += diff # fix the sum by adjusting the first bin
|
||||
freqs[freqs.argmax()] += diff # fix the sum by adjusting the first bin
|
||||
|
||||
return freqs
|
||||
|
||||
|
|
@ -32,32 +35,39 @@ def ae_compress(
|
|||
model: nn.Module,
|
||||
byte_data: bytes,
|
||||
tensor: torch.Tensor
|
||||
|
||||
):
|
||||
# Init AE
|
||||
print("Initializing AE")
|
||||
with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
|
||||
enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
|
||||
|
||||
context = deque([0] * context_length, maxlen=context_length)
|
||||
with open(output_file, "wb") as raw_out:
|
||||
# Write original length header (8 bytes)
|
||||
raw_out.write(struct.pack(">Q", len(byte_data)))
|
||||
|
||||
# Compress
|
||||
for byte in tqdm(tensor.tolist(), desc="Compressing"):
|
||||
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
|
||||
with contextlib.closing(reference_ae.BitOutputStream(raw_out)) as bitout:
|
||||
enc = reference_ae.ArithmeticEncoder(NUMBITS, bitout)
|
||||
|
||||
with torch.inference_mode():
|
||||
logits = model(context_tensor)
|
||||
probabilities = torch.softmax(logits[0], dim=-1)
|
||||
print(f"probabilities: {probabilities}")
|
||||
probabilities = probabilities.detach()
|
||||
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
|
||||
context = deque([0] * context_length, maxlen=context_length)
|
||||
|
||||
# write byte to output file
|
||||
enc.write(probability_table, byte)
|
||||
for byte in tqdm(tensor.tolist(), desc="Compressing"):
|
||||
context_tensor = torch.tensor(
|
||||
[list(context)],
|
||||
dtype=torch.long,
|
||||
device=device
|
||||
)
|
||||
|
||||
context.append(byte)
|
||||
with torch.inference_mode():
|
||||
logits = model(context_tensor)
|
||||
probabilities = torch.softmax(logits[0], dim=-1)
|
||||
|
||||
def chunk_data(x: bytes, context_length = 128) -> torch.Tensor:
|
||||
freqs = probs_to_freqs(probabilities).tolist()
|
||||
probability_table = reference_ae.SimpleFrequencyTable(freqs)
|
||||
|
||||
enc.write(probability_table, byte)
|
||||
context.append(byte)
|
||||
|
||||
enc.finish()
|
||||
|
||||
|
||||
def chunk_data(x: bytes, context_length=128) -> torch.Tensor:
|
||||
tensor_data = torch.tensor(list(x), dtype=torch.long)
|
||||
shape = tensor_data.size(0)
|
||||
row_count = math.ceil(shape / context_length)
|
||||
|
|
@ -65,13 +75,14 @@ def chunk_data(x: bytes, context_length = 128) -> torch.Tensor:
|
|||
tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0)
|
||||
return tensor_data.view(row_count, context_length).float() / 255.0
|
||||
|
||||
|
||||
def auto_encoder_compress(
|
||||
data: bytes,
|
||||
model: AutoEncoder,
|
||||
output_file: str,
|
||||
output_file: str | None = None,
|
||||
context_length: int = 128,
|
||||
device: str = "cuda"
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
# convert data to chunks of context length tensors
|
||||
# send the data to device
|
||||
tensor = chunk_data(data, context_length).to(device)
|
||||
|
|
@ -83,10 +94,11 @@ def auto_encoder_compress(
|
|||
print(f"output shape of compress: {4 * output.shape[0] * output.shape[1]} bytes")
|
||||
|
||||
# write output to file
|
||||
print(f"saving to file {output_file}...")
|
||||
torch.save(output.detach(), output_file)
|
||||
|
||||
if output_file is not None:
|
||||
print(f"saving to file {output_file}...")
|
||||
torch.save(output.detach(), output_file)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def compress(
|
||||
|
|
@ -99,7 +111,7 @@ def compress(
|
|||
):
|
||||
# Get input to compress
|
||||
print("Reading input")
|
||||
if input_file:
|
||||
if input_file is not None:
|
||||
with open(input_file, "rb") as file:
|
||||
byte_data = file.read()
|
||||
else:
|
||||
|
|
@ -111,14 +123,14 @@ def compress(
|
|||
tensor = torch.tensor(list(byte_data), dtype=torch.long)
|
||||
|
||||
# Get model
|
||||
print("Loading model")
|
||||
print(f"Loading model: {model_name}")
|
||||
model = torch.load(model_path, weights_only=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
match model_name:
|
||||
case "cnn":
|
||||
ae_compress(
|
||||
return ae_compress(
|
||||
output_file,
|
||||
context_length,
|
||||
device,
|
||||
|
|
@ -127,7 +139,7 @@ def compress(
|
|||
tensor
|
||||
)
|
||||
case "autoencoder":
|
||||
auto_encoder_compress(
|
||||
return auto_encoder_compress(
|
||||
byte_data,
|
||||
model,
|
||||
output_file,
|
||||
|
|
@ -138,16 +150,75 @@ def compress(
|
|||
raise ValueError(f"Unknown model type: {model_name}")
|
||||
|
||||
|
||||
|
||||
def ae_decompress(
|
||||
|
||||
model: nn.Module,
|
||||
input_file: str,
|
||||
context_length=128,
|
||||
device="cuda",
|
||||
output_file: str | None = None
|
||||
):
|
||||
pass
|
||||
print("Initializing AE decoder")
|
||||
|
||||
with open(input_file, "rb") as raw_in:
|
||||
# Read original length header
|
||||
original_length_bytes = raw_in.read(8)
|
||||
if len(original_length_bytes) != 8:
|
||||
raise ValueError("Invalid compressed file (missing length header)")
|
||||
|
||||
original_length = struct.unpack(">Q", original_length_bytes)[0]
|
||||
print(f"Original length: {original_length} bytes")
|
||||
|
||||
with contextlib.closing(reference_ae.BitInputStream(raw_in)) as bitin:
|
||||
dec = reference_ae.ArithmeticDecoder(NUMBITS, bitin)
|
||||
|
||||
context = deque([0] * context_length, maxlen=context_length)
|
||||
output_data = []
|
||||
|
||||
# Decode exactly original_length bytes
|
||||
for _ in range(original_length):
|
||||
context_tensor = torch.tensor(
|
||||
[list(context)],
|
||||
dtype=torch.long,
|
||||
device=device
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
logits = model(context_tensor)
|
||||
probabilities = torch.softmax(logits[0], dim=-1)
|
||||
|
||||
freqs = probs_to_freqs(probabilities).tolist()
|
||||
probability_table = reference_ae.SimpleFrequencyTable(freqs)
|
||||
|
||||
byte = dec.read(probability_table)
|
||||
output_data.append(byte)
|
||||
context.append(byte)
|
||||
|
||||
byte_data = torch.tensor(output_data, dtype=torch.long).byte()
|
||||
|
||||
if output_file is not None:
|
||||
with open(output_file, "wb") as file:
|
||||
file.write(byte_data.cpu().numpy().tobytes())
|
||||
|
||||
return byte_data
|
||||
|
||||
|
||||
def auto_encoder_decompress(
|
||||
data: torch.Tensor,
|
||||
model: AutoEncoder,
|
||||
output_file: str | None = None,
|
||||
context_length=128,
|
||||
device="cuda"
|
||||
) -> torch.Tensor:
|
||||
decompressed = model.decode(data).squeeze(1)
|
||||
|
||||
):
|
||||
pass
|
||||
# convert result back to bytes
|
||||
byte_data = (decompressed * 255.0).round().byte().detach()
|
||||
|
||||
if output_file is not None:
|
||||
with open(output_file, "wb") as file:
|
||||
file.write(byte_data.cpu().numpy().tobytes())
|
||||
|
||||
return byte_data
|
||||
|
||||
|
||||
def decompress(
|
||||
|
|
@ -156,14 +227,16 @@ def decompress(
|
|||
model_name: str,
|
||||
input_file: str,
|
||||
output_file: str | None = None,
|
||||
context_length: int = 128
|
||||
context_length: int = 128
|
||||
):
|
||||
print("Reading in the data")
|
||||
with open(input_file, "r") as f:
|
||||
length = int(f.readline())
|
||||
bytes_data = f.read()
|
||||
if model_name != "autoencoder":
|
||||
with open(input_file, "rb") as f:
|
||||
data = f.read()
|
||||
else:
|
||||
data = torch.load(input_file, map_location=device)
|
||||
|
||||
if len(bytes_data) == 0:
|
||||
if len(data) == 0:
|
||||
print("Input file is empty, nothing has to be done...")
|
||||
return
|
||||
|
||||
|
|
@ -174,8 +247,19 @@ def decompress(
|
|||
|
||||
match model_name:
|
||||
case "cnn":
|
||||
ae_decompress()
|
||||
return ae_decompress(
|
||||
model=model,
|
||||
input_file=input_file,
|
||||
context_length=context_length,
|
||||
output_file=output_file
|
||||
)
|
||||
case "autoencoder":
|
||||
auto_encoder_decompress()
|
||||
return auto_encoder_decompress(
|
||||
data,
|
||||
model,
|
||||
output_file,
|
||||
context_length,
|
||||
device
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unknown model type: {model_name}")
|
||||
|
|
|
|||
Reference in a new issue