fixup! WIP: Attempt at switching

This commit is contained in:
Tibo De Peuter 2025-12-11 23:58:54 +01:00
parent 2c0e0c2278
commit 961d642dd8
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2

View file

@ -2,12 +2,26 @@ import contextlib
from collections import deque
from decimal import Decimal
import numpy as np
import torch
from tqdm import tqdm
from src.utils import reference_ae
def probs_to_freqs(probs, total_freq=8192):
    """Quantize a probability vector into positive integer frequencies.

    Each probability is scaled by ``total_freq`` and rounded, every bin is
    forced to be at least 1 so that no symbol becomes impossible to encode,
    and the result is then corrected so that it sums to exactly
    ``total_freq``.

    Args:
        probs: Tensor of probabilities. Assumed to be a 1-D ``torch.Tensor``
            of non-negative values summing to ~1 (the code relies on the
            tensor methods ``round``/``long``) -- TODO confirm with caller.
        total_freq: Target sum of the returned frequencies.

    Returns:
        An integer (``long``) tensor with the same shape as ``probs`` whose
        entries are all >= 1 and sum to ``total_freq``.

    Raises:
        ValueError: If the sum cannot be reached while keeping every bin
            >= 1 (i.e. ``total_freq`` is smaller than the number of bins).
    """
    freqs = (probs * total_freq).round().long()
    # Ensure no zero-frequency symbol: a coder cannot encode a symbol whose
    # frequency is 0.
    freqs = freqs.clamp(min=1)

    # Re-normalize so the sum matches total_freq.
    diff = int(total_freq - freqs.sum())
    if diff > 0:
        # Rounding lost some mass: give it to the most probable symbol,
        # where the relative distortion is smallest.
        freqs[freqs.argmax()] += diff
    else:
        # Rounding/clamping added mass: take it back from the largest bins.
        # Always adjusting bin 0 (as before) could push a small bin to zero
        # or below, producing an invalid frequency table.
        while diff < 0:
            idx = freqs.argmax()
            take = min(-diff, int(freqs[idx]) - 1)
            if take <= 0:
                raise ValueError(
                    "total_freq is too small to give every symbol a "
                    "frequency of at least 1"
                )
            freqs[idx] -= take
            diff += take
    return freqs
def compress(
device,
model_path: str,
@ -51,15 +65,22 @@ def compress(
with torch.inference_mode():
logits = model(context_tensor)
#normalize
mean = logits.mean(dim=-1, keepdim=True)
std = logits.std(dim=-1, keepdim=True)
logits = (logits - mean) / (std + 1e-6)
print(f"logits: {logits}")
probabilities = torch.softmax(logits[0], dim=-1)
probabilities = probabilities.detach().cpu().numpy()
print(f"probabilities: {probabilities}")
probabilities = probabilities.detach()
eps = 1e-10
frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities))
probability_table = AE.get_probability_table(frequency_table)
eps = 1e-8
# np.add(probabilities, eps)
# frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
# probability_table = AE.get_probability_table(frequency_table)
enc.write(frequency_table, byte)
enc.write(probability_table, byte)
context.append(byte)