132 lines
3.9 KiB
Python
132 lines
3.9 KiB
Python
import numpy as np
|
|
from . import rans
|
|
from utils.distributions import discretized_logistic_cdf, \
|
|
mixture_discretized_logistic_cdf
|
|
import torch
|
|
|
|
# rANS coder precision in bits: symbol probabilities are quantized to
# integer frequencies out of 2^precision.
precision = 24

# Number of discretization bins per variable; each symbol indexes into a
# CDF table of n_bins entries centered on the predicted mean.
n_bins = 4096
|
|
|
|
|
|
def cdf_fn(z, pz, variable_type, distribution_type, inverse_bin_width):
    """Evaluate the model CDF at z for the given variable/distribution type.

    Only the discrete + logistic combinations are implemented: a plain
    discretized logistic when pz has two parameter tensors, and a mixture
    of discretized logistics when it has three.  Every other combination
    (discrete/normal and all continuous variants) is unimplemented and
    falls through to ValueError.
    """
    if variable_type == 'discrete' and distribution_type == 'logistic':
        n_params = len(pz)
        if n_params == 2:
            # pz = (mean, logscale)
            return discretized_logistic_cdf(
                z, *pz, inverse_bin_width=inverse_bin_width)
        if n_params == 3:
            # pz = (means, logscales, mixture weights)
            return mixture_discretized_logistic_cdf(
                z, *pz, inverse_bin_width=inverse_bin_width)
    # Unimplemented combination (or unexpected parameter count).
    raise ValueError
|
|
|
|
|
|
def CDF_fn(pz, bin_width, variable_type, distribution_type):
    """Build integer CDF tables for rANS coding, one window per variable.

    For each variable a window of n_bins bin locations is centered on the
    (rounded) predicted mean, the model CDF is evaluated at the left edge
    of every bin, and the result is quantized to integer frequencies with
    every bin guaranteed at least 1/2^precision probability mass.

    Returns:
        CDFs: int numpy array, last axis of size n_bins (quantized CDF).
        MEAN: long tensor of per-variable rounded means (needed by the
              caller to shift symbols into / out of table index range).
    """
    # Mean of the distribution: for a plain logistic pz = (mean, scale);
    # for a mixture, take the middle mixture component's mean.
    mean = pz[0] if len(pz) == 2 else pz[0][..., (pz[0].size(-1) - 1) // 2]
    MEAN = torch.round(mean / bin_width).long()

    # n_bins integer offsets per variable, centered on the rounded mean,
    # then scaled back to continuous bin locations.
    bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
    bin_locations = bin_locations.float() * bin_width
    # BUG FIX: torch tensors expose `.device`, not `.DEVICE` (the original
    # raised AttributeError here on every call).
    bin_locations = bin_locations.to(device=pz[0].device)

    # Broadcast the distribution parameters against the bin axis.
    pz = [param[:, :, :, :, None] for param in pz]
    cdf = cdf_fn(
        bin_locations - bin_width,
        pz,
        variable_type,
        distribution_type,
        1. / bin_width).cpu().numpy()

    # Compute CDFs, reweigh to give all bins at least
    # 1 / (2^precision) probability.
    # CDF is equal to floor[cdf * (2^precision - n_bins)] + range(n_bins)
    CDFs = (cdf * ((1 << precision) - n_bins)).astype('int') \
        + np.arange(n_bins)

    return CDFs, MEAN
|
|
|
|
|
|
def encode_sample(
        z, pz, variable_type, distribution_type, bin_width=1./256, state=None):
    """Compress discretized sample z with rANS under the model given by pz.

    Returns the flattened coder state, or None if any symbol falls outside
    the n_bins window (compression is then not possible).
    """
    # Fresh coder state unless one is passed in (flattened) to extend.
    state = rans.x_init if state is None else rans.unflatten(state)

    CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)

    # z is transformed to Z to match the indices for the CDFs array
    Z = (torch.round(z / bin_width).long() + n_bins // 2 - MEAN).cpu().numpy()

    # statfun_encode reads CDF[symbol + 1], so valid symbols are
    # 0 <= Z < n_bins - 1.
    if np.any(Z < 0) or np.any(Z >= n_bins - 1):
        print('Z out of allowed range of values, canceling compression')
        return None

    symbols = Z.reshape(-1)
    tables = CDFs.reshape(-1, n_bins).copy()
    # rANS decodes in reverse order of encoding, so push symbols back-to-front.
    for symbol, cdf in zip(symbols[::-1], tables[::-1]):
        push = rans.append_symbol(statfun_encode(cdf), precision)
        state = push(state, symbol)

    return rans.flatten(state)
|
|
|
|
|
|
def decode_sample(
        state, pz, variable_type, distribution_type, bin_width=1./256):
    """Decompress one sample from the rANS state under the model given by pz.

    pz is only used for its predicted distribution (to rebuild the same CDF
    tables the encoder used) and for the target device/shape.

    Returns:
        state: the flattened coder state with the sample's symbols popped.
        z: the reconstructed (discretized) sample tensor.
    """
    state = rans.unflatten(state)

    # BUG FIX: torch tensors expose `.device`, not `.DEVICE` (the original
    # raised AttributeError here on every call).
    device = pz[0].device
    size = pz[0].size()[0:4]

    CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)

    # Pop one symbol per variable, in the same (forward) order the encoder
    # pushed them last-to-first.
    CDFs = CDFs.reshape(-1, n_bins)
    result = np.zeros(len(CDFs), dtype=int)
    for i, cdf in enumerate(CDFs):
        statfun = statfun_decode(cdf)
        state, symbol = rans.pop_symbol(statfun, precision)(state)
        result[i] = symbol

    # Undo the index shift applied by the encoder and rescale to values.
    Z_flat = torch.from_numpy(result).to(device)
    Z = Z_flat.view(size) - n_bins // 2 + MEAN

    z = Z.float() * bin_width

    state = rans.flatten(state)

    return state, z
|
|
|
|
|
|
def statfun_encode(CDF):
    """Wrap a cumulative-frequency table into the lookup rANS encoding needs.

    The returned function maps a symbol index to its (start, freq) pair in
    the quantized CDF.
    """
    def _lookup(symbol):
        start = CDF[symbol]
        freq = CDF[symbol + 1] - start
        return start, freq
    return _lookup
|
|
|
|
|
|
def statfun_decode(CDF):
    """Wrap a cumulative-frequency table into the lookup rANS decoding needs.

    The returned function maps a cumulative frequency back to the symbol
    that owns it, together with that symbol's (start, freq) pair.
    """
    def _statfun_decode(cf):
        # Binary search for s such that CDF[s] <= cf < CDF[s + 1]
        # (the original comment stated the upper bound incorrectly).
        s = np.searchsorted(CDF, cf, side='right') - 1
        # Same (start, freq) arithmetic as statfun_encode, inlined so the
        # lookup is self-contained.
        start = CDF[s]
        freq = CDF[s + 1] - start
        return s, (start, freq)
    return _statfun_decode
|
|
|
|
|
|
def encode(x, symbol):
    # NOTE(review): `statfun_encode` is a factory that expects a CDF table
    # and returns the actual statfun; passing the factory itself means rans
    # would call it with a symbol instead. This looks wrong — possibly dead
    # code, since encode_sample constructs the statfun properly. TODO confirm.
    return rans.append_symbol(statfun_encode, precision)(x, symbol)
|
|
|
|
|
|
def decode(x):
    # NOTE(review): as with encode() above, `statfun_decode` is a factory
    # expecting a CDF table, not a ready statfun; passing it directly looks
    # wrong — possibly dead code, since decode_sample constructs the statfun
    # properly per variable. TODO confirm.
    return rans.pop_symbol(statfun_decode, precision)(x)
|