fixup! WIP: Attempt at switching
This commit is contained in:
parent
2c0e0c2278
commit
961d642dd8
1 changed files with 27 additions and 6 deletions
|
|
@ -2,12 +2,26 @@ import contextlib
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from src.utils import reference_ae
|
from src.utils import reference_ae
|
||||||
|
|
||||||
|
|
||||||
|
def probs_to_freqs(probs, total_freq=8192):
    """Quantize a probability tensor into positive integer frequencies.

    Each probability is scaled by ``total_freq`` and rounded. Every symbol is
    then given a frequency of at least 1, and the rounding residual is
    absorbed by the largest bin(s) so that the result sums to exactly
    ``total_freq`` while every entry stays >= 1.

    The residual used to be dumped onto bin 0 unconditionally, which could
    push that bin to zero or below when many tiny probabilities were bumped
    up to 1. Taking the correction from the largest bin avoids that and
    distorts the distribution the least.

    Args:
        probs: torch tensor of probabilities (expected to sum to roughly 1).
        total_freq: Target sum of the returned frequencies.

    Returns:
        An int64 tensor shaped like ``probs`` whose entries are all >= 1 and
        sum to ``total_freq``.

    Raises:
        ValueError: If ``probs`` is empty, or ``total_freq`` is smaller than
            the number of symbols (a minimum of 1 per symbol cannot fit).
    """
    num_symbols = probs.numel()
    if num_symbols == 0:
        raise ValueError("probs must contain at least one symbol")
    if total_freq < num_symbols:
        raise ValueError(
            f"total_freq ({total_freq}) must be >= number of symbols ({num_symbols})"
        )

    # No symbol may have zero frequency, otherwise it cannot be encoded.
    freqs = (probs * total_freq).round().long().clamp_(min=1)
    flat = freqs.reshape(-1)  # view on freqs; edits below land in freqs

    diff = int(total_freq - flat.sum())
    if diff > 0:
        # Rounding lost some mass: hand it to the most probable symbol.
        flat[flat.argmax()] += diff
    while diff < 0:
        # Too much mass (rounding up and/or the clamp to 1): take it from the
        # largest bin, never letting any bin fall below 1. Terminates because
        # sum > total_freq >= num_symbols guarantees some bin is > 1.
        idx = flat.argmax()
        take = min(-diff, int(flat[idx]) - 1)
        flat[idx] -= take
        diff += take

    return freqs
|
||||||
|
|
||||||
|
|
||||||
def compress(
|
def compress(
|
||||||
device,
|
device,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
|
|
@ -51,15 +65,22 @@ def compress(
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
logits = model(context_tensor)
|
logits = model(context_tensor)
|
||||||
|
#normalize
|
||||||
|
mean = logits.mean(dim=-1, keepdim=True)
|
||||||
|
std = logits.std(dim=-1, keepdim=True)
|
||||||
|
logits = (logits - mean) / (std + 1e-6)
|
||||||
|
print(f"logits: {logits}")
|
||||||
probabilities = torch.softmax(logits[0], dim=-1)
|
probabilities = torch.softmax(logits[0], dim=-1)
|
||||||
probabilities = probabilities.detach().cpu().numpy()
|
print(f"probabilities: {probabilities}")
|
||||||
|
probabilities = probabilities.detach()
|
||||||
|
|
||||||
eps = 1e-10
|
eps = 1e-8
|
||||||
frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
|
# np.add(probabilities, eps)
|
||||||
probability_table = reference_ae.SimpleFrequencyTable([0] * len(probabilities))
|
# frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
|
||||||
probability_table = AE.get_probability_table(frequency_table)
|
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
|
||||||
|
# probability_table = AE.get_probability_table(frequency_table)
|
||||||
|
|
||||||
enc.write(frequency_table, byte)
|
enc.write(probability_table, byte)
|
||||||
|
|
||||||
context.append(byte)
|
context.append(byte)
|
||||||
|
|
||||||
|
|
|
||||||
Reference in a new issue