fixup! WIP: Attempt at switching
This commit is contained in:
parent
2c0e0c2278
commit
961d642dd8
1 changed files with 27 additions and 6 deletions
|
|
@ -2,12 +2,26 @@ import contextlib
|
|||
from collections import deque
|
||||
from decimal import Decimal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.utils import reference_ae
|
||||
|
||||
|
||||
def probs_to_freqs(probs, total_freq=8192):
|
||||
freqs = (probs * total_freq).round().long()
|
||||
|
||||
# Ensure no zero-frequency symbol if needed
|
||||
freqs[freqs == 0] = 1
|
||||
|
||||
# Re-normalize so the sum matches total_freq
|
||||
diff = total_freq - freqs.sum()
|
||||
freqs[0] += diff # fix the sum by adjusting the first bin
|
||||
|
||||
return freqs
|
||||
|
||||
|
||||
def compress(
|
||||
device,
|
||||
model_path: str,
|
||||
|
|
@ -51,15 +65,22 @@ def compress(
|
|||
|
||||
with torch.inference_mode():
|
||||
logits = model(context_tensor)
|
||||
#normalize
|
||||
mean = logits.mean(dim=-1, keepdim=True)
|
||||
std = logits.std(dim=-1, keepdim=True)
|
||||
logits = (logits - mean) / (std + 1e-6)
|
||||
print(f"logits: {logits}")
|
||||
probabilities = torch.softmax(logits[0], dim=-1)
|
||||
probabilities = probabilities.detach().cpu().numpy()
|
||||
print(f"probabilities: {probabilities}")
|
||||
probabilities = probabilities.detach()
|
||||
|
||||
eps = 1e-10
|
||||
frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
|
||||
probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities))
|
||||
probability_table = AE.get_probability_table(frequency_table)
|
||||
eps = 1e-8
|
||||
# np.add(probabilities, eps)
|
||||
# frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
|
||||
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
|
||||
# probability_table = AE.get_probability_table(frequency_table)
|
||||
|
||||
enc.write(frequency_table, byte)
|
||||
enc.write(probability_table, byte)
|
||||
|
||||
context.append(byte)
|
||||
|
||||
|
|
|
|||
Reference in a new issue