fixup! WIP: Attempt at switching

This commit is contained in:
Tibo De Peuter 2025-12-11 23:58:54 +01:00
parent 2c0e0c2278
commit 961d642dd8
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2

View file

@ -2,12 +2,26 @@ import contextlib
from collections import deque from collections import deque
from decimal import Decimal from decimal import Decimal
import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from src.utils import reference_ae from src.utils import reference_ae
def probs_to_freqs(probs, total_freq=8192):
    """Quantise a probability vector into integer symbol frequencies.

    Every symbol receives a frequency of at least 1 and the frequencies sum
    to exactly ``total_freq``.

    The rounding error is absorbed by the largest bin(s) rather than by bin 0:
    when many near-zero bins are bumped up to 1 the total overshoots
    ``total_freq``, and subtracting that excess from a small first bin would
    drive it to zero or below.

    Args:
        probs: Tensor of per-symbol probabilities. Assumed to be 1-D and to
            sum to roughly 1 -- TODO confirm against the caller.
        total_freq: Target sum of the returned frequencies.

    Returns:
        An integer (``long``) tensor with the same number of elements as
        ``probs``, all entries >= 1, summing to ``total_freq``.

    Raises:
        ValueError: If ``probs`` is empty, or holds more symbols than
            ``total_freq`` (each symbol needs a frequency of at least 1).
    """
    freqs = (probs * total_freq).round().long()

    num_symbols = freqs.numel()
    if num_symbols == 0:
        raise ValueError("probs must contain at least one symbol")
    if num_symbols > total_freq:
        raise ValueError(
            f"total_freq={total_freq} is too small to give each of "
            f"{num_symbols} symbols a non-zero frequency"
        )

    # Ensure no symbol ends up with a zero (or negative) frequency.
    freqs[freqs < 1] = 1

    # Re-normalise so the sum matches total_freq exactly.
    diff = int(total_freq - freqs.sum())
    if diff > 0:
        # Undershoot: the largest bin is distorted least by the top-up.
        freqs[int(freqs.argmax())] += diff
    while diff < 0:
        # Overshoot: shave the excess off the largest bin(s) without letting
        # any of them fall below 1. This terminates because the sum exceeds
        # total_freq >= num_symbols, so some bin is always greater than 1.
        largest = int(freqs.argmax())
        take = min(-diff, int(freqs[largest]) - 1)
        freqs[largest] -= take
        diff += take
    return freqs
def compress( def compress(
device, device,
model_path: str, model_path: str,
@ -51,15 +65,22 @@ def compress(
with torch.inference_mode(): with torch.inference_mode():
logits = model(context_tensor) logits = model(context_tensor)
#normalize
mean = logits.mean(dim=-1, keepdim=True)
std = logits.std(dim=-1, keepdim=True)
logits = (logits - mean) / (std + 1e-6)
print(f"logits: {logits}")
probabilities = torch.softmax(logits[0], dim=-1) probabilities = torch.softmax(logits[0], dim=-1)
probabilities = probabilities.detach().cpu().numpy() print(f"probabilities: {probabilities}")
probabilities = probabilities.detach()
eps = 1e-10 eps = 1e-8
frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} # np.add(probabilities, eps)
probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities)) # frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
probability_table = AE.get_probability_table(frequency_table) probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
# probability_table = AE.get_probability_table(frequency_table)
enc.write(frequency_table, byte) enc.write(probability_table, byte)
context.append(byte) context.append(byte)