diff --git a/src/process.py b/src/process.py
index e59defd..31de886 100644
--- a/src/process.py
+++ b/src/process.py
@@ -2,12 +2,29 @@
 import contextlib
 from collections import deque
 from decimal import Decimal
 import torch
 from tqdm import tqdm
 
 from src.utils import reference_ae
 
 
+def probs_to_freqs(probs, total_freq=8192):
+    """Quantize a probability vector into an integer frequency table.
+
+    Every symbol keeps a frequency of at least 1 so the arithmetic coder
+    can always encode it; the total is renormalized to ``total_freq``.
+    """
+    freqs = (probs * total_freq).round().long()
+    freqs[freqs == 0] = 1
+
+    # Put the rounding surplus/deficit on the largest bin: it can absorb
+    # the correction without being driven to zero (bin 0 might hold a 1).
+    diff = total_freq - freqs.sum()
+    freqs[freqs.argmax()] += diff
+
+    return freqs
+
+
 def compress(
     device,
     model_path: str,
@@ -51,13 +68,16 @@ def compress(
         with torch.inference_mode():
             logits = model(context_tensor)
 
+            # Standardize logits before softmax (acts like temperature
+            # scaling). NOTE(review): the decompressor must apply the
+            # identical transform or decoding will desynchronize — confirm.
+            mean = logits.mean(dim=-1, keepdim=True)
+            std = logits.std(dim=-1, keepdim=True)
+            logits = (logits - mean) / (std + 1e-6)
             probabilities = torch.softmax(logits[0], dim=-1)
-            probabilities = probabilities.detach().cpu().numpy()
+            probabilities = probabilities.detach()
 
-            eps = 1e-10
-            frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
-            probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities))
-            probability_table = AE.get_probability_table(frequency_table)
+            probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
 
-            enc.write(frequency_table, byte)
+            enc.write(probability_table, byte)
             context.append(byte)