feat: initial for IDF
This commit is contained in:
commit
ef4684ef39
27 changed files with 2830 additions and 0 deletions
0
integer_discrete_flows/coding/__init__.py
Normal file
0
integer_discrete_flows/coding/__init__.py
Normal file
132
integer_discrete_flows/coding/coder.py
Normal file
132
integer_discrete_flows/coding/coder.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
import numpy as np
|
||||
from . import rans
|
||||
from utils.distributions import discretized_logistic_cdf, \
|
||||
mixture_discretized_logistic_cdf
|
||||
import torch
|
||||
|
||||
# Frequency resolution handed to the rANS coder: the integer CDF tables are
# scaled against (1 << precision).
precision = 1 << 3 | 1 << 4
# Number of entries in every per-element CDF table.
n_bins = 1 << 12
|
||||
|
||||
|
||||
def cdf_fn(z, pz, variable_type, distribution_type, inverse_bin_width):
    """Evaluate the prior CDF at ``z``.

    Only the discrete logistic prior is implemented.  ``pz`` is unpacked
    positionally into the CDF function: two entries select
    ``discretized_logistic_cdf``, three entries select
    ``mixture_discretized_logistic_cdf``.

    Args:
        z: points at which the CDF is evaluated.
        pz: sequence of prior parameters (2 or 3 entries).
        variable_type: ``'discrete'`` is the only supported value.
        distribution_type: ``'logistic'`` is the only supported value.
        inverse_bin_width: forwarded as keyword to the CDF function.

    Returns:
        Whatever the selected CDF function returns.

    Raises:
        ValueError: for any other combination of ``variable_type``,
            ``distribution_type`` and number of parameters.
    """
    if variable_type == 'discrete' and distribution_type == 'logistic':
        n_params = len(pz)
        if n_params == 2:
            return discretized_logistic_cdf(
                z, *pz, inverse_bin_width=inverse_bin_width)
        if n_params == 3:
            return mixture_discretized_logistic_cdf(
                z, *pz, inverse_bin_width=inverse_bin_width)
        raise ValueError(
            'discrete logistic prior expects 2 or 3 parameters, '
            'got {}'.format(n_params))

    # Every other (variable_type, distribution_type) pair is unimplemented;
    # say which one was requested instead of raising a bare ValueError.
    raise ValueError(
        'unsupported prior: variable_type={!r}, distribution_type={!r}'.format(
            variable_type, distribution_type))
|
||||
|
||||
|
||||
def CDF_fn(pz, bin_width, variable_type, distribution_type):
    """Build integer CDF tables on a grid of ``n_bins`` bins around the mean.

    Args:
        pz: sequence of prior parameter tensors; each one is indexed with four
            leading axes (``param[:, :, :, :, None]``).
        bin_width: width of one quantisation bin.
        variable_type: forwarded to ``cdf_fn``.
        distribution_type: forwarded to ``cdf_fn``.

    Returns:
        ``(CDFs, MEAN)`` where ``CDFs`` is an integer numpy array whose last
        axis has ``n_bins`` entries and ``MEAN`` is the rounded mean expressed
        in bin units (a long tensor).
    """
    # With two parameters the first one is used as the mean directly;
    # otherwise the middle entry along the last axis of pz[0] is taken.
    if len(pz) == 2:
        mean = pz[0]
    else:
        middle = (pz[0].size(-1) - 1) // 2
        mean = pz[0][..., middle]
    MEAN = torch.round(mean / bin_width).long()

    # Integer bin offsets centred on MEAN, then converted back to values.
    offsets = torch.arange(-n_bins // 2, n_bins // 2)
    grid = offsets[None, None, None, None, :] + MEAN.cpu()[..., None]
    grid = grid.float() * bin_width
    grid = grid.to(device=pz[0].device)

    # Trailing singleton axis so the parameters broadcast against the grid.
    expanded = [param[:, :, :, :, None] for param in pz]
    cdf = cdf_fn(
        grid - bin_width,
        expanded,
        variable_type,
        distribution_type,
        1./bin_width)
    cdf = cdf.cpu().numpy()

    # Reweigh so that every bin keeps at least 1 / (2^precision) probability:
    # CDF = floor[cdf * (2^precision - n_bins)] + range(n_bins)
    scale = (1 << precision) - n_bins
    CDFs = (cdf * scale).astype('int') + np.arange(n_bins)

    return CDFs, MEAN
|
||||
|
||||
|
||||
def encode_sample(
        z, pz, variable_type, distribution_type, bin_width=1./256, state=None):
    """Entropy-code ``z`` under the prior ``pz`` with rANS.

    Args:
        z: tensor of values to encode.
        pz: prior parameters, forwarded to ``CDF_fn``.
        variable_type: forwarded to ``CDF_fn``.
        distribution_type: forwarded to ``CDF_fn``.
        bin_width: width of one quantisation bin.
        state: flattened rANS state to continue from, or ``None`` to start
            from ``rans.x_init``.

    Returns:
        The flattened rANS state after all symbols have been pushed, or
        ``None`` when any symbol index falls outside ``[0, n_bins - 2]``.
    """
    coder_state = rans.x_init if state is None else rans.unflatten(state)

    CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)

    # Shift the quantised values so they index directly into the CDF tables.
    indices = torch.round(z / bin_width).long() + n_bins // 2 - MEAN
    indices = indices.cpu().numpy()

    # The last table entry has no upper neighbour, so n_bins - 1 is invalid.
    if (indices < 0).any() or (indices >= n_bins-1).any():
        print('Z out of allowed range of values, canceling compression')
        return None

    flat_indices = indices.reshape(-1)
    tables = CDFs.reshape(-1, n_bins).copy()
    # Push in reverse order: the state is a stack, so symbols pop forwards.
    for symbol, table in zip(flat_indices[::-1], tables[::-1]):
        push = rans.append_symbol(statfun_encode(table), precision)
        coder_state = push(coder_state, symbol)

    return rans.flatten(coder_state)
|
||||
|
||||
|
||||
def decode_sample(
        state, pz, variable_type, distribution_type, bin_width=1./256):
    """Decode one tensor of values from a flattened rANS state.

    Args:
        state: flattened rANS state (as produced by ``rans.flatten``).
        pz: prior parameters, forwarded to ``CDF_fn``; ``pz[0]`` also supplies
            the output device and the first four dimensions of the output.
        variable_type: forwarded to ``CDF_fn``.
        distribution_type: forwarded to ``CDF_fn``.
        bin_width: width of one quantisation bin.

    Returns:
        ``(state, z)``: the flattened remaining rANS state and the decoded
        float tensor.
    """
    coder_state = rans.unflatten(state)

    target_device = pz[0].device
    target_shape = pz[0].size()[0:4]

    CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)
    tables = CDFs.reshape(-1, n_bins)

    # Pop one symbol per CDF table, in forward order.
    decoded = np.zeros(len(tables), dtype=int)
    for idx, table in enumerate(tables):
        pop_one = rans.pop_symbol(statfun_decode(table), precision)
        coder_state, symbol = pop_one(coder_state)
        decoded[idx] = symbol

    # Undo the index shift applied at encoding time and rescale to values.
    indices = torch.from_numpy(decoded).to(target_device).view(target_shape)
    Z = indices - n_bins // 2 + MEAN
    z = Z.float() * bin_width

    return rans.flatten(coder_state), z
|
||||
|
||||
|
||||
def statfun_encode(CDF):
    """Bind a CDF table and return a ``symbol -> (start, freq)`` function.

    ``start`` is ``CDF[symbol]`` and ``freq`` is the distance to the next
    table entry, so ``symbol + 1`` must be a valid index into ``CDF``.
    """
    def _statfun_encode(symbol):
        lower = CDF[symbol]
        upper = CDF[symbol + 1]
        return lower, upper - lower
    return _statfun_encode
|
||||
|
||||
|
||||
def statfun_decode(CDF):
    """Bind a CDF table and return a ``cf -> (symbol, (start, freq))`` function."""
    def _statfun_decode(cf):
        # Find s such that CDF[s] <= cf < CDF[s + 1]: side='right' yields the
        # first index whose entry exceeds cf, so the symbol is one before it.
        s = np.searchsorted(CDF, cf, side='right') - 1
        start, freq = statfun_encode(CDF)(s)
        return s, (start, freq)
    return _statfun_decode
|
||||
|
||||
|
||||
def encode(x, symbol, cdf=None):
    """Push a single ``symbol`` onto the rANS state ``x``.

    ``statfun_encode`` is a factory: it has to be bound to a CDF table before
    it can act as the statistics function expected by ``rans.append_symbol``.
    Handing the factory over unbound (the previous behaviour) made the
    ``start, freq = statfun(symbol)`` unpacking fail with a TypeError on every
    call, so the table is now taken as an explicit argument.

    Args:
        x: rANS state (unflattened).
        symbol: index of the symbol to encode.
        cdf: integer CDF table for the symbol.

    Returns:
        The new rANS state.

    Raises:
        TypeError: if ``cdf`` is not supplied.
    """
    if cdf is None:
        raise TypeError(
            'encode() requires a cdf table: statfun_encode must be bound to '
            'a CDF before it can be used as a rANS statfun')
    return rans.append_symbol(statfun_encode(cdf), precision)(x, symbol)
|
||||
|
||||
|
||||
def decode(x, cdf=None):
    """Pop a single symbol from the rANS state ``x``.

    ``statfun_decode`` is a factory: it has to be bound to a CDF table before
    it can act as the statistics function expected by ``rans.pop_symbol``.
    Handing the factory over unbound (the previous behaviour) made the
    ``symbol, (start, freq) = statfun(cf)`` unpacking fail with a TypeError on
    every call, so the table is now taken as an explicit argument.

    Args:
        x: rANS state (unflattened).
        cdf: integer CDF table the symbol was encoded with.

    Returns:
        Whatever ``rans.pop_symbol`` returns: ``(new_state, symbol)``.

    Raises:
        TypeError: if ``cdf`` is not supplied.
    """
    if cdf is None:
        raise TypeError(
            'decode() requires a cdf table: statfun_decode must be bound to '
            'a CDF before it can be used as a rANS statfun')
    return rans.pop_symbol(statfun_decode(cdf), precision)(x)
|
||||
67
integer_discrete_flows/coding/rans.py
Normal file
67
integer_discrete_flows/coding/rans.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""
|
||||
Closely based on https://github.com/rygorous/ryg_rans/blob/master/rans64.h
|
||||
|
||||
ROUGH GUIDE:
|
||||
We use the pythonic names 'append' and 'pop' for encoding and decoding
|
||||
respectively. The compressed state 'x' is an immutable stack, implemented using
|
||||
a cons list.
|
||||
|
||||
x: the current stack-like state of the encoder/decoder.
|
||||
|
||||
precision: the natural numbers are divided into ranges of size 2^precision.
|
||||
|
||||
start & freq: start indicates the beginning of the range in [0, 2^precision-1]
|
||||
that the current symbol is represented by. freq is the length of the range.
|
||||
freq is chosen such that p(symbol) ~= freq/2^precision.
|
||||
"""
|
||||
import numpy as np
|
||||
from functools import reduce
|
||||
|
||||
|
||||
# Lower bound of the normalisation interval for the head of the state.
rans_l = 2 ** 31
# Mask selecting the low 32 bits of the head when a word is moved to the tail.
tail_bits = 0xFFFFFFFF

# Initial state: head at the lower bound, empty tail stack.
x_init = (rans_l, ())
|
||||
|
||||
def append(x, start, freq, precision):
    """Encode a symbol occupying the range [start, start + freq).

    All frequencies are assumed to sum to "1 << precision".  ``x`` is a
    ``(head, tail)`` pair; the updated pair is returned and ``x`` itself is
    left untouched.
    """
    head, tail = x[0], x[1]

    # Renormalise: if the head is too large to absorb this symbol, move its
    # low 32 bits onto the tail stack first.
    limit = ((rans_l >> precision) << 32) * freq
    if head >= limit:
        tail = (head & tail_bits, tail)
        head = head >> 32

    quotient = head // freq
    remainder = head % freq
    return (quotient << precision) + remainder + start, tail
|
||||
|
||||
def pop(x_, precision):
    """Start popping a single symbol from the state ``x_``.

    Returns ``(cf, advance)``: ``cf`` is the low ``precision`` bits of the
    head, used by the caller to look up the symbol, and ``advance(start,
    freq)`` returns the state with that symbol removed.
    """
    low_mask = (1 << precision) - 1
    cf = x_[0] & low_mask

    def _advance(start, freq):
        head = freq * (x_[0] >> precision) + cf - start
        tail = x_[1]
        # Renormalise: pull a 32-bit word back from the tail stack when the
        # head has dropped below the lower bound.
        if head < rans_l:
            return (head << 32) | tail[0], tail[1]
        return head, tail

    return cf, _advance
|
||||
|
||||
def append_symbol(statfun, precision):
    """Return an ``(x, symbol) -> x`` encoder built on ``append``.

    ``statfun`` maps a symbol to its ``(start, freq)`` pair.
    """
    def append_(x, symbol):
        range_start, range_size = statfun(symbol)
        return append(x, range_start, range_size, precision)
    return append_
|
||||
|
||||
def pop_symbol(statfun, precision):
    """Return an ``x -> (x, symbol)`` decoder built on ``pop``.

    ``statfun`` maps the cumulative value read from the state to
    ``(symbol, (start, freq))``.
    """
    def pop_(x):
        cf, advance = pop(x, precision)
        symbol, (range_start, range_size) = statfun(cf)
        next_state = advance(range_start, range_size)
        return next_state, symbol
    return pop_
|
||||
|
||||
def flatten(x):
    """Flatten a rans state x into a 1d numpy array of uint32 words.

    Layout: high 32 bits of the head, low 32 bits of the head, then the words
    of the tail stack from top to bottom.
    """
    word_mask = (1 << 32) - 1
    head, tail = int(x[0]), x[1]
    # Mask the low word explicitly rather than relying on wrap-around in the
    # uint32 cast: NumPy rejects out-of-range Python integers (deprecated in
    # 1.24, OverflowError from 2.0).
    out = [head >> 32, head & word_mask]
    while tail:
        word, tail = tail
        out.append(word)
    return np.asarray(out, dtype=np.uint32)
|
||||
|
||||
def unflatten(arr):
    """Unflatten a 1d numpy array into a rans state.

    The first two words form the high and low halves of the head; the
    remaining words become the nested ``(word, rest)`` tail stack, with
    ``arr[2]`` on top.
    """
    head = (int(arr[0]) << 32) | int(arr[1])
    tail = ()
    # Build the cons list from the bottom up so arr[2] ends up outermost.
    for word in reversed(arr[2:]):
        tail = (int(word), tail)
    return head, tail
|
||||
Reference in a new issue