This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../integer_discrete_flows/models/priors.py
2025-11-07 12:54:36 +01:00

164 lines
4.7 KiB
Python

"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from utils.distributions import sample_discretized_logistic, \
sample_mixture_discretized_logistic, sample_normal, sample_logistic, \
sample_discretized_normal, sample_mixture_normal
from models.utils import Base
from .networks import NN
def sample_prior(px, variable_type, distribution_type, inverse_bin_width):
    """Draw a sample from the prior described by the parameter tuple ``px``.

    Dispatches to the matching sampler in ``utils.distributions`` based on
    whether the variable is discrete or continuous, the distribution family,
    and whether ``px`` carries mixture weights (len 3) or not (len 2).

    Args:
        px: Tuple of distribution parameters, e.g. ``(mu, logs)`` or
            ``(mu, logs, pi)`` for mixtures.
        variable_type: ``'discrete'`` or ``'continuous'``.
        distribution_type: ``'logistic'``, ``'normal'`` or ``'steplogistic'``.
        inverse_bin_width: Discretization level (``2**n_bits``), only used
            for discrete variables.

    Returns:
        A sampled tensor from the selected distribution.

    Raises:
        ValueError: If the combination of ``variable_type``,
            ``distribution_type`` and ``len(px)`` is not supported.
    """
    if variable_type == 'discrete':
        if distribution_type == 'logistic':
            if len(px) == 2:
                return sample_discretized_logistic(
                    *px, inverse_bin_width=inverse_bin_width)
            elif len(px) == 3:
                return sample_mixture_discretized_logistic(
                    *px, inverse_bin_width=inverse_bin_width)
        elif distribution_type == 'normal':
            return sample_discretized_normal(
                *px, inverse_bin_width=inverse_bin_width)
    elif variable_type == 'continuous':
        if distribution_type == 'logistic':
            return sample_logistic(*px)
        elif distribution_type == 'normal':
            if len(px) == 2:
                return sample_normal(*px)
            elif len(px) == 3:
                return sample_mixture_normal(*px)
        elif distribution_type == 'steplogistic':
            return sample_logistic(*px)
    # Fall-through: no sampler matched. Name the offending combination so
    # the caller can diagnose it (the original bare ValueError was opaque).
    raise ValueError(
        'Unsupported prior: variable_type={!r}, distribution_type={!r}, '
        'len(px)={}'.format(variable_type, distribution_type, len(px)))
class Prior(Base):
    """Factorized prior over the latent z, optionally a mixture.

    Holds learnable location (``mu``) and log-scale (``logs``) parameters per
    latent element; with ``n_mixtures > 1`` an additional mixture-logit
    tensor ``pi_logit`` is kept, with the mixture axis last.
    """

    def __init__(self, size, args):
        super().__init__()
        n_channels, height, width = size
        self.inverse_bin_width = 2 ** args.n_bits
        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type
        self.n_mixtures = args.n_mixtures

        if self.n_mixtures == 1:
            shape = (n_channels, height, width)
            self.mu = Parameter(torch.Tensor(*shape))
            self.logs = Parameter(torch.Tensor(*shape))
        elif self.n_mixtures > 1:
            shape = (n_channels, height, width, self.n_mixtures)
            self.mu = Parameter(torch.Tensor(*shape))
            self.logs = Parameter(torch.Tensor(*shape))
            self.pi_logit = Parameter(torch.Tensor(*shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Zero-initialize parameters; spread mixture means around zero."""
        self.mu.data.zero_()
        if self.n_mixtures > 1:
            self.pi_logit.data.zero_()
            # Center the k mixture components symmetrically around 0.
            center = (self.n_mixtures - 1) / 2.
            for k in range(self.n_mixtures):
                self.mu.data[..., k] += k - center
        self.logs.data.zero_()

    def get_pz(self, n):
        """Tile the prior parameters along a fresh batch dimension of size n."""
        if self.n_mixtures == 1:
            tile = (n, 1, 1, 1)
            return self.mu.repeat(*tile), self.logs.repeat(*tile)
        elif self.n_mixtures > 1:
            tile = (n, 1, 1, 1, 1)
            mixture_weights = F.softmax(self.pi_logit, dim=-1)
            return (self.mu.repeat(*tile),
                    self.logs.repeat(*tile),
                    mixture_weights.repeat(*tile))

    def forward(self, z, ldj):
        """Return the prior parameters for z's batch along with z and ldj."""
        return self.get_pz(z.size(0)), z, ldj

    def sample(self, n):
        """Sample a batch of n latents from the prior."""
        pz = self.get_pz(n)
        return sample_prior(
            pz, self.variable_type, self.distribution_type,
            self.inverse_bin_width)

    def decode(self, states, decode_fn):
        """Decode latents from entropy-coder states using the prior."""
        pz = self.get_pz(n=len(states))
        states, z = decode_fn(states, pz)
        return states, z
class SplitPrior(Base):
    """Split-prior layer: factors out channels and models them conditionally.

    The input is split along the channel axis into a kept part ``z`` and a
    factored-out part ``y``; a small network predicts the distribution
    parameters of ``y`` given ``z``.
    """

    def __init__(self, c_in, factor_out, height, width, args):
        super().__init__()
        self.split_idx = c_in - factor_out
        self.inverse_bin_width = 2 ** args.n_bits
        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type
        self.input_channel = c_in
        # Conditional network: maps the kept channels to (mu, logs) for the
        # factored-out channels, interleaved along the channel axis.
        self.nn = NN(
            args=args,
            c_in=self.split_idx,
            c_out=factor_out * 2,
            height=height,
            width=width,
            nn_type=args.splitprior_type)

    def get_py(self, z):
        """Predict [mu, logs] for the factored-out part from z."""
        h = self.nn(z)
        # Even channels carry mu, odd channels carry logs.
        return [h[:, 0::2, :, :], h[:, 1::2, :, :]]

    def split(self, z):
        """Split z into the kept channels and the factored-out channels."""
        kept = z[:, :self.split_idx, :, :]
        factored = z[:, self.split_idx:, :, :]
        return kept, factored

    def combine(self, z, y):
        """Concatenate kept and factored-out channels back together."""
        return torch.cat((z, y), dim=1)

    def forward(self, z, ldj):
        """Split z; return y's predicted distribution, y, the rest, and ldj."""
        z, y = self.split(z)
        return self.get_py(z), y, z, ldj

    def inverse(self, z, ldj, y):
        """Recombine z with y, sampling y from the conditional prior if None."""
        if y is None:
            py = self.get_py(z)
            y = sample_prior(
                py, self.variable_type, self.distribution_type,
                self.inverse_bin_width)
        return self.combine(z, y), ldj

    def decode(self, z, ldj, states, decode_fn):
        """Decode the factored-out part from states and recombine with z."""
        py = self.get_py(z)
        states, y = decode_fn(states, py)
        return self.combine(z, y), ldj, states