"""
Collection of flow strategies
"""
|
|
|
|
from __future__ import print_function
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.nn import Parameter
|
|
from utils.distributions import sample_discretized_logistic, \
|
|
sample_mixture_discretized_logistic, sample_normal, sample_logistic, \
|
|
sample_discretized_normal, sample_mixture_normal
|
|
from models.utils import Base
|
|
from .networks import NN
|
|
|
|
|
|
def sample_prior(px, variable_type, distribution_type, inverse_bin_width):
    """Draw a sample from the distribution described by ``px``.

    The sampler is selected from ``variable_type`` and ``distribution_type``.
    For the discrete logistic and the continuous normal case, the number of
    entries in ``px`` additionally picks between the single-component
    sampler (two entries) and the mixture sampler (three entries).

    Args:
        px: Sequence of distribution parameters; it is unpacked positionally
            into the selected sampler.
        variable_type: ``'discrete'`` or ``'continuous'``.
        distribution_type: ``'logistic'`` or ``'normal'``. For continuous
            variables ``'steplogistic'`` is accepted as well and is sampled
            with the plain logistic sampler.
        inverse_bin_width: Forwarded as the ``inverse_bin_width`` keyword to
            the discretized samplers. Not used for continuous variables.

    Returns:
        The value returned by the selected sampler.

    Raises:
        ValueError: If no sampler exists for the given combination of
            variable type, distribution type and number of entries in ``px``.
    """
    if variable_type == 'discrete':
        if distribution_type == 'logistic':
            if len(px) == 2:
                return sample_discretized_logistic(
                    *px, inverse_bin_width=inverse_bin_width)
            if len(px) == 3:
                return sample_mixture_discretized_logistic(
                    *px, inverse_bin_width=inverse_bin_width)

        elif distribution_type == 'normal':
            return sample_discretized_normal(
                *px, inverse_bin_width=inverse_bin_width)

    elif variable_type == 'continuous':
        # 'steplogistic' shares the sampler of the plain logistic.
        if distribution_type in ('logistic', 'steplogistic'):
            return sample_logistic(*px)

        if distribution_type == 'normal':
            if len(px) == 2:
                return sample_normal(*px)
            if len(px) == 3:
                return sample_mixture_normal(*px)

    # Every supported configuration has returned by now; say what was asked
    # for so a misconfiguration can be diagnosed from the traceback alone.
    raise ValueError(
        'Unsupported prior configuration: variable_type={!r}, '
        'distribution_type={!r} (or an unexpected number of distribution '
        'parameters in px)'.format(variable_type, distribution_type))
|
|
|
|
|
|
class Prior(Base):
    """Learned prior whose parameters do not depend on the input.

    The module owns a ``mu`` and a ``logs`` parameter of shape ``(c, h, w)``
    (``logs`` is presumably a log-scale; here it is only stored and handed
    on). With ``n_mixtures > 1`` each parameter gets a trailing mixture axis
    and a ``pi_logit`` parameter holds the mixture logits.

    Args:
        size: Three-element sequence ``(c, h, w)``.
        args: Configuration object providing ``n_bits``, ``variable_type``,
            ``distribution_type`` and ``n_mixtures``.

    Raises:
        ValueError: If ``args.n_mixtures`` is smaller than one.
    """

    def __init__(self, size, args):
        super().__init__()
        c, h, w = size

        self.inverse_bin_width = 2**args.n_bits
        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type
        self.n_mixtures = args.n_mixtures

        # Without this check no parameter would be created and the failure
        # would only surface later as a missing-attribute error.
        if self.n_mixtures < 1:
            raise ValueError(
                'n_mixtures must be at least 1, got {!r}'.format(
                    self.n_mixtures))

        if self.n_mixtures == 1:
            shape = (c, h, w)
        else:
            shape = (c, h, w, self.n_mixtures)

        # Storage is left uninitialised here; reset_parameters() fills it.
        self.mu = Parameter(torch.empty(*shape))
        self.logs = Parameter(torch.empty(*shape))
        if self.n_mixtures > 1:
            self.pi_logit = Parameter(torch.empty(*shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Reset all parameters to their initial values.

        ``logs`` and ``pi_logit`` become zero. ``mu`` becomes zero for a
        single component; for a mixture, component ``i`` is set to
        ``i - (n_mixtures - 1) / 2`` so the components start out spread
        symmetrically around zero.
        """
        with torch.no_grad():
            self.mu.zero_()

            if self.n_mixtures > 1:
                self.pi_logit.zero_()
                centre = (self.n_mixtures - 1) / 2.
                for i in range(self.n_mixtures):
                    self.mu[..., i] += i - centre

            self.logs.zero_()

    def get_pz(self, n):
        """Return the prior parameters tiled ``n`` times along a new leading axis.

        Args:
            n: Number of copies (batch size).

        Returns:
            ``(mu, logs)`` for a single component, or ``(mu, logs, pi)`` for
            a mixture, where ``pi`` is the softmax of ``pi_logit`` over the
            last axis.
        """
        if self.n_mixtures == 1:
            mu = self.mu.repeat(n, 1, 1, 1)
            logs = self.logs.repeat(n, 1, 1, 1)
            return mu, logs

        pi = F.softmax(self.pi_logit, dim=-1)
        mu = self.mu.repeat(n, 1, 1, 1, 1)
        logs = self.logs.repeat(n, 1, 1, 1, 1)
        pi = pi.repeat(n, 1, 1, 1, 1)
        return mu, logs, pi

    def forward(self, z, ldj):
        """Return the prior parameters for a batch together with the inputs.

        Args:
            z: Tensor whose first dimension gives the batch size.
            ldj: Passed through unchanged.

        Returns:
            Tuple ``(pz, z, ldj)`` where ``pz`` is ``get_pz(z.size(0))``.
        """
        pz = self.get_pz(z.size(0))

        return pz, z, ldj

    def sample(self, n):
        """Draw ``n`` samples from the prior via ``sample_prior``."""
        pz = self.get_pz(n)

        z_sample = sample_prior(
            pz, self.variable_type, self.distribution_type,
            self.inverse_bin_width)

        return z_sample

    def decode(self, states, decode_fn):
        """Decode a latent from ``states`` under this prior.

        Args:
            states: Sized collection; its length is used as the batch size.
            decode_fn: Callable invoked as ``decode_fn(states, pz)``; it must
                return a ``(states, z)`` pair.

        Returns:
            The ``(states, z)`` pair produced by ``decode_fn``.
        """
        pz = self.get_pz(n=len(states))

        states, z = decode_fn(states, pz)
        return states, z
|
|
|
|
|
|
class SplitPrior(Base):
    """Prior for channels that are split off, conditioned on the kept ones.

    The input is divided along the channel axis at ``c_in - factor_out``.
    A network (``NN``) maps the leading part to ``2 * factor_out`` channels,
    which are de-interleaved into ``mu`` (even channels) and ``logs`` (odd
    channels) for the trailing part.

    Args:
        c_in: Number of channels of the incoming tensor.
        factor_out: Number of channels that are split off.
        height: Forwarded to ``NN``.
        width: Forwarded to ``NN``.
        args: Configuration object providing ``n_bits``, ``variable_type``,
            ``distribution_type`` and ``splitprior_type``; it is also handed
            to ``NN``.
    """

    def __init__(self, c_in, factor_out, height, width, args):
        super().__init__()

        self.input_channel = c_in
        self.split_idx = c_in - factor_out

        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type
        self.inverse_bin_width = 2**args.n_bits

        # Two output channels (mu, logs) per split-off channel.
        self.nn = NN(
            args=args,
            c_in=c_in - factor_out,
            c_out=factor_out * 2,
            height=height,
            width=width,
            nn_type=args.splitprior_type,
        )

    def get_py(self, z):
        """Run the network on ``z`` and return ``[mu, logs]``.

        ``mu`` is taken from the even channels of the network output and
        ``logs`` from the odd ones.
        """
        params = self.nn(z)

        location = params[:, ::2, :, :]
        log_scale = params[:, 1::2, :, :]

        return [location, log_scale]

    def split(self, z):
        """Cut ``z`` along the channel axis at ``split_idx``.

        Returns:
            Tuple ``(z1, y)``: the channels before and from ``split_idx``.
        """
        kept = z[:, :self.split_idx, :, :]
        factored = z[:, self.split_idx:, :, :]
        return kept, factored

    def combine(self, z, y):
        """Concatenate ``z`` and ``y`` along the channel axis (dim 1)."""
        return torch.cat([z, y], dim=1)

    def forward(self, z, ldj):
        """Split ``z`` and compute the parameters for the split-off part.

        Returns:
            Tuple ``(py, y, z, ldj)`` where ``z`` is the kept part, ``y`` the
            split-off part, ``py`` the output of ``get_py`` on the kept part
            and ``ldj`` is passed through unchanged.
        """
        kept, factored = self.split(z)
        py = self.get_py(kept)
        return py, factored, kept, ldj

    def inverse(self, z, ldj, y):
        """Re-attach the split-off channels to ``z``.

        Args:
            z: The kept channels.
            ldj: Passed through unchanged.
            y: The split-off channels, or ``None`` to sample them from the
                conditional prior.

        Returns:
            Tuple ``(z, ldj)`` with ``z`` being the concatenation of the kept
            channels and ``y``.
        """
        if y is None:
            # Nothing was provided, so draw y from the conditional prior.
            y = sample_prior(
                self.get_py(z),
                self.variable_type,
                self.distribution_type,
                self.inverse_bin_width,
            )

        return self.combine(z, y), ldj

    def decode(self, z, ldj, states, decode_fn):
        """Decode the split-off channels from ``states`` and re-attach them.

        Args:
            z: The kept channels.
            ldj: Passed through unchanged.
            states: Handed to ``decode_fn``.
            decode_fn: Callable invoked as ``decode_fn(states, py)``; it must
                return a ``(states, y)`` pair.

        Returns:
            Tuple ``(combined, ldj, states)``.
        """
        py = self.get_py(z)
        states, decoded = decode_fn(states, py)
        merged = self.combine(z, decoded)
        return merged, ldj, states
|