feat: initial for IDF
commit ef4684ef39
27 changed files with 2830 additions and 0 deletions
191  integer_discrete_flows/models/Model.py  Normal file
@@ -0,0 +1,191 @@
import torch
import models.generative_flows as generative_flows
import numpy as np
from models.utils import Base
from .priors import Prior
from optimization.loss import compute_loss_array
from coding.coder import encode_sample, decode_sample


class Normalize(Base):
    def __init__(self, args):
        super().__init__()
        self.n_bits = args.n_bits
        self.variable_type = args.variable_type
        self.input_size = args.input_size

    def forward(self, x, ldj, reverse=False):
        domain = 2.**self.n_bits

        if self.variable_type == 'discrete':
            # Discrete variables will be measured on intervals sized 1/domain.
            # Hence, there is no need to change the log Jacobian determinant.
            dldj = 0
        elif self.variable_type == 'continuous':
            dldj = -np.log(domain) * np.prod(self.input_size)
        else:
            raise ValueError

        if not reverse:
            x = (x - domain / 2) / domain
            ldj += dldj
        else:
            x = x * domain + domain / 2
            ldj -= dldj

        return x, ldj
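As a quick numeric check of the transform above (assuming the default n_bits=8, so domain = 256): pixel values are shifted and scaled into roughly [-0.5, 0.5), and only the continuous case adds a (negative) constant to the log-determinant.

import torch

domain = 2. ** 8                     # 256 bins for 8-bit data
x = torch.tensor([0., 128., 255.])
print((x - domain / 2) / domain)     # tensor([-0.5000,  0.0000,  0.4961])
# For 'discrete' variables dldj = 0; for 'continuous' data with input_size
# (3, 32, 32) the constant -log(256) * 3*32*32 is added to ldj per example.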
class Model(Base):
    """
    The base VAE class containing gated convolutional encoder and decoder
    architecture. Can be used as a base class for VAE's with normalizing flows.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type

        n_channels, height, width = args.input_size

        self.normalize = Normalize(args)

        self.flow = generative_flows.GenerativeFlow(
            n_channels, height, width, args)

        self.n_bits = args.n_bits

        self.z_size = self.flow.z_size

        self.prior = Prior(self.z_size, args)

    def dequantize(self, x):
        if self.training:
            x = x + torch.rand_like(x)
        else:
            # Required for stability.
            alpha = 1e-3
            x = x + alpha + torch.rand_like(x) * (1 - 2 * alpha)

        return x

    def loss(self, pz, z, pys, ys, ldj):
        batchsize = z.size(0)
        loss, bpd, bpd_per_prior = \
            compute_loss_array(pz, z, pys, ys, ldj, self.args)

        for module in self.modules():
            if hasattr(module, 'auxillary_loss'):
                loss += module.auxillary_loss() / batchsize

        return loss, bpd, bpd_per_prior

    def forward(self, x):
        """
        Evaluates the model as a whole, encodes and decodes. Note that the log
        det jacobian is zero for a plain VAE (without flows), and z_0 = z_k.
        """
        # Encode x to z.

        assert x.dtype == torch.uint8

        x = x.float()

        ldj = torch.zeros_like(x[:, 0, 0, 0])
        if self.variable_type == 'continuous':
            x = self.dequantize(x)
        elif self.variable_type == 'discrete':
            pass
        else:
            raise ValueError

        x, ldj = self.normalize(x, ldj)

        z, ldj, pys, ys = self.flow(x, ldj, pys=(), ys=())

        pz, z, ldj = self.prior(z, ldj)

        loss, bpd, bpd_per_prior = self.loss(pz, z, pys, ys, ldj)

        return loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj

    def inverse(self, z, ys):
        ldj = torch.zeros_like(z[:, 0, 0, 0])
        x, ldj, pys, py = \
            self.flow(z, ldj, pys=[], ys=ys, reverse=True)

        x, ldj = self.normalize(x, ldj, reverse=True)

        x_uint8 = torch.clamp(x, min=0, max=255).to(
            torch.uint8)

        return x_uint8

    def sample(self, n):
        z_sample = self.prior.sample(n)

        ldj = torch.zeros_like(z_sample[:, 0, 0, 0])
        x_sample, ldj, pys, py = \
            self.flow(z_sample, ldj, pys=[], ys=[], reverse=True)

        x_sample, ldj = self.normalize(x_sample, ldj, reverse=True)

        x_sample_uint8 = torch.clamp(x_sample, min=0, max=255).to(
            torch.uint8)

        return x_sample_uint8

    def encode(self, x):
        batchsize = x.size(0)
        _, _, _, pz, z, pys, ys, _ = self.forward(x)

        pjs = list(pys) + [pz]
        js = list(ys) + [z]

        states = []

        for b in range(batchsize):
            state = None
            for pj, j in zip(pjs, js):
                pj_b = [param[b:b+1] for param in pj]
                j_b = j[b:b+1]

                state = encode_sample(
                    j_b, pj_b, self.variable_type,
                    self.distribution_type, state=state)
                if state is None:
                    break

            states.append(state)

        return states

    def decode(self, states):
        def decode_fn(states, pj):
            states = list(states)
            j = []

            for b in range(len(states)):
                pj_b = [param[b:b+1] for param in pj]

                states[b], j_b = decode_sample(
                    states[b], pj_b, self.variable_type,
                    self.distribution_type)

                j.append(j_b)

            j = torch.cat(j, dim=0)
            return states, j

        states, z = self.prior.decode(states, decode_fn=decode_fn)

        ldj = torch.zeros_like(z[:, 0, 0, 0])

        x, ldj = self.flow.decode(z, ldj, states, decode_fn=decode_fn)

        x, ldj = self.normalize(x, ldj, reverse=True)
        x = x.to(dtype=torch.uint8)

        return x
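A minimal usage sketch of the full model; `args` is assumed to be the argparse-style namespace built elsewhere in the repository, carrying the fields read above (input_size, n_bits, variable_type, distribution_type and the flow options).

import torch

model = Model(args)                                             # args from the training script (assumed)
x = torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8)    # a batch of uint8 images
loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj = model(x)
x_rec = model.inverse(z, ys)       # map z (and factored-out ys) back to uint8 pixels
samples = model.sample(n=4)        # unconditional samples from the prior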
0  integer_discrete_flows/models/__init__.py  Normal file

151  integer_discrete_flows/models/backround.py  Normal file
@@ -0,0 +1,151 @@
import torch
import torch.nn.functional as F
import numpy as np

from models.utils import Base


class RoundStraightThrough(torch.autograd.Function):

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input):
        rounded = torch.round(input, out=None)
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input


_round_straightthrough = RoundStraightThrough().apply


def _stacked_sigmoid(x, temperature, n_approx=3):
    # Smooth approximation of round(x): a sum of n_approx shifted sigmoids forms
    # a local staircase around x; lower temperatures sharpen the steps.
    x_ = x - 0.5
    rounded = torch.round(x_)
    x_remainder = x_ - rounded

    size = x_.size()
    x_remainder = x_remainder.view(size + (1,))

    translation = torch.arange(n_approx) - n_approx // 2
    translation = translation.to(device=x.device, dtype=x.dtype)
    translation = translation.view([1] * len(size) + [len(translation)])
    out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)

    return out + rounded - (n_approx // 2)
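A small illustrative check of the approximation (values chosen away from the half-integer transition points): at a low temperature the stacked sigmoid tracks torch.round closely while still providing non-zero gradients.

import torch

x = torch.tensor([-1.2, -0.3, 0.25, 0.8, 1.75], requires_grad=True)
y = _stacked_sigmoid(x, temperature=0.05)
print(y)        # close to round(x): approximately [-1., -0., 0., 1., 2.]
y.sum().backward()
print(x.grad)   # non-zero sigmoid gradients, unlike the zero gradient of torch.round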
class SmoothRound(Base):
    def __init__(self):
        self._temperature = None
        self._n_approx = None
        super().__init__()
        self.hard_round = None

    @property
    def temperature(self):
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        self._temperature = value

        if self._temperature <= 0.05:
            self._n_approx = 1
        elif 0.05 < self._temperature < 0.13:
            self._n_approx = 3
        else:
            self._n_approx = 5

    def forward(self, x):
        assert self._temperature is not None
        assert self._n_approx is not None
        assert self.hard_round is not None

        if self.temperature <= 0.25:
            h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx)
        else:
            h = x

        if self.hard_round:
            h = _round_straightthrough(h)

        return h


class StochasticRound(Base):
    def __init__(self):
        super().__init__()
        self.hard_round = None

    def forward(self, x):
        u = torch.rand_like(x)

        h = x + u - 0.5

        if self.hard_round:
            h = _round_straightthrough(h)

        return h


class BackRound(Base):

    def __init__(self, args, inverse_bin_width):
        """
        BackRound is an approximation to Round that allows for backpropagation.

        Approximate the round function using a sum of translated sigmoids.
        The temperature determines how well the round function is approximated,
        i.e., a lower temperature corresponds to a better approximation, at
        the cost of more vanishing gradients.

        BackRound supports the following settings:
        * By setting hard to True and temperature > 0.25, BackRound
          reduces to a round function with a straight-through gradient
          estimator.
        * When using 0 < temperature <= 0.25 and hard = True, the
          output in the forward pass is equivalent to a round function, but the
          gradient is approximated by the gradient of a sum of sigmoids.
        * When using hard = False, the output is not constrained to integers.
        * When temperature > 0.25 and hard = False, BackRound reduces to
          the identity function.

        Arguments
        ---------
        temperature: float
            Temperature used for the stacked sigmoid approximation. If the
            temperature is greater than 0.25, the approximation reduces to the
            identity function.
        hard: bool
            If hard is True, a (hard) round is applied before returning. The
            gradient for this is approximated using the straight-through
            estimator.
        """
        super().__init__()
        self.inverse_bin_width = inverse_bin_width
        self.round_approx = args.round_approx

        if args.round_approx == 'smooth':
            self.round = SmoothRound()
        elif args.round_approx == 'stochastic':
            self.round = StochasticRound()
        else:
            raise ValueError

    def forward(self, x):
        if self.round_approx == 'smooth' or self.round_approx == 'stochastic':
            h = x * self.inverse_bin_width

            h = self.round(h)

            return h / self.inverse_bin_width

        else:
            raise ValueError
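A minimal sketch of BackRound in use; the `_Args` stub below is hypothetical and only carries the single field the constructor reads. With hard rounding enabled, the forward pass lands exactly on multiples of 1/inverse_bin_width while gradients flow through the sigmoid/straight-through surrogate.

import torch

class _Args:                       # hypothetical stub, only the field read above
    round_approx = 'smooth'

br = BackRound(_Args(), inverse_bin_width=256)
br.set_temperature(0.1)            # propagated to the SmoothRound child via Base
br.enable_hard_round()

x = torch.randn(4, requires_grad=True)
y = br(x)                                    # multiples of 1/256 in the forward pass
print(torch.all(torch.frac(y * 256) == 0))   # tensor(True)
y.sum().backward()
print(x.grad)                                # defined, thanks to the surrogate gradient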
142  integer_discrete_flows/models/coupling.py  Normal file
@@ -0,0 +1,142 @@
"""
|
||||
Collection of flow strategies
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from models.utils import Base
|
||||
from .backround import BackRound
|
||||
from .networks import NN
|
||||
|
||||
|
||||
UNIT_TESTING = False
|
||||
|
||||
|
||||
class SplitFactorCoupling(Base):
|
||||
def __init__(self, c_in, factor, height, width, args):
|
||||
super().__init__()
|
||||
self.n_channels = args.n_channels
|
||||
self.kernel = 3
|
||||
self.input_channel = c_in
|
||||
self.round_approx = args.round_approx
|
||||
|
||||
if args.variable_type == 'discrete':
|
||||
self.round = BackRound(
|
||||
args, inverse_bin_width=2**args.n_bits)
|
||||
else:
|
||||
self.round = None
|
||||
|
||||
self.split_idx = c_in - (c_in // factor)
|
||||
|
||||
self.nn = NN(
|
||||
args=args,
|
||||
c_in=self.split_idx,
|
||||
c_out=c_in - self.split_idx,
|
||||
height=height,
|
||||
width=width,
|
||||
kernel=self.kernel,
|
||||
nn_type=args.coupling_type)
|
||||
|
||||
def forward(self, z, ldj, reverse=False):
|
||||
z1 = z[:, :self.split_idx, :, :]
|
||||
z2 = z[:, self.split_idx:, :, :]
|
||||
|
||||
t = self.nn(z1)
|
||||
|
||||
if self.round is not None:
|
||||
t = self.round(t)
|
||||
|
||||
if not reverse:
|
||||
z2 = z2 + t
|
||||
else:
|
||||
z2 = z2 - t
|
||||
|
||||
z = torch.cat([z1, z2], dim=1)
|
||||
|
||||
return z, ldj
|
||||
|
||||
|
||||
class Coupling(Base):
|
||||
def __init__(self, c_in, height, width, args):
|
||||
super().__init__()
|
||||
|
||||
if args.split_quarter:
|
||||
factor = 4
|
||||
elif args.splitfactor > 1:
|
||||
factor = args.splitfactor
|
||||
else:
|
||||
factor = 2
|
||||
|
||||
self.coupling = SplitFactorCoupling(
|
||||
c_in, factor, height, width, args=args)
|
||||
|
||||
def forward(self, z, ldj, reverse=False):
|
||||
return self.coupling(z, ldj, reverse)
|
||||
|
||||
|
||||
def test_generative_flow():
    import models.networks as networks
    global UNIT_TESTING

    networks.UNIT_TESTING = True
    UNIT_TESTING = True

    batch_size = 17

    input_size = [12, 16, 16]

    class Args():
        def __init__(self):
            self.input_size = input_size
            self.learn_split = False
            self.variable_type = 'continuous'
            self.distribution_type = 'logistic'
            self.round_approx = 'smooth'
            self.coupling_type = 'shallow'
            self.conv_type = 'standard'
            self.densenet_depth = 8
            self.bottleneck = False
            self.n_channels = 512
            self.network1x1 = 'standard'
            self.auxilary_freq = -1
            self.actnorm = False
            self.LU = False
            self.coupling_lifting_L = True
            self.splitprior = True
            self.split_quarter = True
            self.n_levels = 2
            self.n_flows = 2
            self.cond_L = True
            self.n_bits = True

    args = Args()

    x = (torch.randint(256, size=[batch_size] + input_size).float() - 128.) / 256.
    ldj = torch.zeros_like(x[:, 0, 0, 0])

    model = Coupling(c_in=12, height=16, width=16, args=args)

    print(model)

    model.set_temperature(1.)
    model.enable_hard_round()

    model.eval()

    z, ldj = model(x, ldj, reverse=False)

    # Check if gradient computation works
    loss = torch.sum(z**2)
    loss.backward()

    recon, ldj = model(z, ldj, reverse=True)

    sse = torch.sum(torch.pow(x - recon, 2)).item()
    ae = torch.abs(x - recon).sum()
    print('Error in recon: sse {} ae {}'.format(sse / np.prod(input_size), ae))


if __name__ == '__main__':
    test_generative_flow()
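One property worth noting about the additive coupling above: since the same rounded translation is added on the forward pass and subtracted on the inverse, the map is exactly invertible, and integer inputs stay integer. A small stand-in check (shapes arbitrary; t stands in for round(NN(z1))):

import torch

z1 = torch.randint(0, 256, (1, 6, 8, 8)).float()   # conditioning half
z2 = torch.randint(0, 256, (1, 6, 8, 8)).float()   # transformed half
t = torch.round(torch.randn(1, 6, 8, 8) * 3)       # stands in for round(NN(z1))

z2_fwd = z2 + t                                # forward coupling
z2_rec = z2_fwd - t                            # inverse coupling
print(torch.equal(z2_rec, z2))                 # True: exact inverse
print(torch.all(torch.frac(z2_fwd) == 0))      # True: output stays on the integer grid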
175  integer_discrete_flows/models/generative_flows.py  Normal file
@@ -0,0 +1,175 @@
"""
|
||||
Collection of flow strategies
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from models.utils import Base
|
||||
from .priors import SplitPrior
|
||||
from .coupling import Coupling
|
||||
|
||||
|
||||
UNIT_TESTING = False
|
||||
|
||||
|
||||
def space_to_depth(x):
|
||||
xs = x.size()
|
||||
# Pick off every second element
|
||||
x = x.view(xs[0], xs[1], xs[2] // 2, 2, xs[3] // 2, 2)
|
||||
# Transpose picked elements next to channels.
|
||||
x = x.permute((0, 1, 3, 5, 2, 4)).contiguous()
|
||||
# Combine with channels.
|
||||
x = x.view(xs[0], xs[1] * 4, xs[2] // 2, xs[3] // 2)
|
||||
return x
|
||||
|
||||
|
||||
def depth_to_space(x):
|
||||
xs = x.size()
|
||||
# Pick off elements from channels
|
||||
x = x.view(xs[0], xs[1] // 4, 2, 2, xs[2], xs[3])
|
||||
# Transpose picked elements next to HW dimensions.
|
||||
x = x.permute((0, 1, 4, 2, 5, 3)).contiguous()
|
||||
# Combine with HW dimensions.
|
||||
x = x.view(xs[0], xs[1] // 4, xs[2] * 2, xs[3] * 2)
|
||||
return x
|
||||
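A quick sanity check of the two rearrangements (import path assumed from the file location): space_to_depth folds every 2x2 spatial block into the channel dimension, and depth_to_space undoes it exactly.

import torch
from models.generative_flows import space_to_depth, depth_to_space  # assumed import path

x = torch.arange(1 * 3 * 4 * 4, dtype=torch.float32).view(1, 3, 4, 4)
y = space_to_depth(x)
print(y.shape)                              # torch.Size([1, 12, 2, 2])
print(torch.equal(depth_to_space(y), x))    # True: the two ops are mutual inverses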
def int_shape(x):
    return list(map(int, x.size()))


class Flatten(Base):
    def forward(self, x):
        return x.view(x.size(0), -1)


class Reshape(Base):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.size(0), *self.shape)


class Reverse(Base):
    def __init__(self):
        super().__init__()

    def forward(self, z, reverse=False):
        flip_idx = torch.arange(z.size(1) - 1, -1, -1).long()
        z = z[:, flip_idx, :, :]
        return z


class Permute(Base):
    def __init__(self, n_channels):
        super().__init__()

        permutation = np.arange(n_channels, dtype='int')
        np.random.shuffle(permutation)

        permutation_inv = np.zeros(n_channels, dtype='int')
        permutation_inv[permutation] = np.arange(n_channels, dtype='int')

        self.permutation = torch.from_numpy(permutation)
        self.permutation_inv = torch.from_numpy(permutation_inv)

    def forward(self, z, ldj, reverse=False):
        if not reverse:
            z = z[:, self.permutation, :, :]
        else:
            z = z[:, self.permutation_inv, :, :]

        return z, ldj

    def InversePermute(self):
        inv_permute = Permute(len(self.permutation))
        inv_permute.permutation = self.permutation_inv
        inv_permute.permutation_inv = self.permutation
        return inv_permute


class Squeeze(Base):
    def __init__(self):
        super().__init__()

    def forward(self, z, ldj, reverse=False):
        if not reverse:
            z = space_to_depth(z)
        else:
            z = depth_to_space(z)
        return z, ldj


class GenerativeFlow(Base):
    def __init__(self, n_channels, height, width, args):
        super().__init__()
        layers = []
        layers.append(Squeeze())
        n_channels *= 4
        height //= 2
        width //= 2

        for level in range(args.n_levels):

            for i in range(args.n_flows):
                perm_layer = Permute(n_channels)
                layers.append(perm_layer)

                layers.append(
                    Coupling(n_channels, height, width, args))

            if level < args.n_levels - 1:
                if args.splitprior_type != 'none':
                    # Standard splitprior
                    factor_out = n_channels // 2
                    layers.append(SplitPrior(n_channels, factor_out, height, width, args))
                    n_channels = n_channels - factor_out

                layers.append(Squeeze())
                n_channels *= 4
                height //= 2
                width //= 2

        self.layers = torch.nn.ModuleList(layers)
        self.z_size = (n_channels, height, width)

    def forward(self, z, ldj, pys=(), ys=(), reverse=False):
        if not reverse:
            for l, layer in enumerate(self.layers):
                if isinstance(layer, (SplitPrior)):
                    py, y, z, ldj = layer(z, ldj)
                    pys += (py,)
                    ys += (y,)

                else:
                    z, ldj = layer(z, ldj)

        else:
            for l, layer in reversed(list(enumerate(self.layers))):
                if isinstance(layer, (SplitPrior)):
                    if len(ys) > 0:
                        z, ldj = layer.inverse(z, ldj, y=ys[-1])
                        # Pop last element
                        ys = ys[:-1]
                    else:
                        z, ldj = layer.inverse(z, ldj, y=None)

                else:
                    z, ldj = layer(z, ldj, reverse=True)

        return z, ldj, pys, ys

    def decode(self, z, ldj, state, decode_fn):

        for l, layer in reversed(list(enumerate(self.layers))):
            if isinstance(layer, SplitPrior):
                z, ldj, state = layer.decode(z, ldj, state, decode_fn)

            else:
                z, ldj = layer(z, ldj, reverse=True)

        return z, ldj
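For orientation, the channel and spatial bookkeeping in the constructor above can be traced by hand; a small sketch under assumed settings (3x32x32 input, n_levels=2, splitprior enabled), mirroring the loop structure:

# Mirrors the constructor: one initial squeeze, then per level (except the last)
# a splitprior that factors out half the channels followed by another squeeze.
n_channels, height, width = 3, 32, 32
n_levels = 2

n_channels, height, width = n_channels * 4, height // 2, width // 2          # initial Squeeze
for level in range(n_levels):
    if level < n_levels - 1:
        n_channels -= n_channels // 2                                         # SplitPrior factors out half
        n_channels, height, width = n_channels * 4, height // 2, width // 2   # Squeeze
print((n_channels, height, width))   # (24, 8, 8): the z_size kept by the flow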
154  integer_discrete_flows/models/networks.py  Normal file
@@ -0,0 +1,154 @@
"""
|
||||
Collection of flow strategies
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from models.utils import Base
|
||||
|
||||
|
||||
UNIT_TESTING = False
|
||||
|
||||
|
||||
class Conv2dReLU(Base):
|
||||
def __init__(
|
||||
self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0,
|
||||
bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.nn = nn.Conv2d(n_inputs, n_outputs, kernel_size, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.nn(x)
|
||||
|
||||
y = F.relu(h)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class ResidualBlock(Base):
|
||||
def __init__(self, n_channels, kernel, Conv2dAct):
|
||||
super().__init__()
|
||||
|
||||
self.nn = torch.nn.Sequential(
|
||||
Conv2dAct(n_channels, n_channels, kernel, padding=1),
|
||||
torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.nn(x)
|
||||
h = F.relu(h + x)
|
||||
return h
|
||||
|
||||
|
||||
class DenseLayer(Base):
|
||||
def __init__(self, args, n_inputs, growth, Conv2dAct):
|
||||
super().__init__()
|
||||
|
||||
conv1x1 = Conv2dAct(
|
||||
n_inputs, n_inputs, kernel_size=1, stride=1,
|
||||
padding=0, bias=True)
|
||||
|
||||
self.nn = torch.nn.Sequential(
|
||||
conv1x1,
|
||||
Conv2dAct(
|
||||
n_inputs, growth, kernel_size=3, stride=1,
|
||||
padding=1, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.nn(x)
|
||||
|
||||
h = torch.cat([x, h], dim=1)
|
||||
return h
|
||||
|
||||
|
||||
class DenseBlock(Base):
|
||||
def __init__(
|
||||
self, args, n_inputs, n_outputs, kernel, Conv2dAct):
|
||||
super().__init__()
|
||||
depth = args.densenet_depth
|
||||
|
||||
future_growth = n_outputs - n_inputs
|
||||
|
||||
layers = []
|
||||
|
||||
for d in range(depth):
|
||||
growth = future_growth // (depth - d)
|
||||
|
||||
layers.append(DenseLayer(args, n_inputs, growth, Conv2dAct))
|
||||
n_inputs += growth
|
||||
future_growth -= growth
|
||||
|
||||
self.nn = torch.nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.nn(x)
|
||||
|
||||
|
||||
class Identity(Base):
|
||||
def __init__(self):
|
||||
super.__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class NN(Base):
|
||||
def __init__(
|
||||
self, args, c_in, c_out, height, width, nn_type, kernel=3):
|
||||
super().__init__()
|
||||
|
||||
Conv2dAct = Conv2dReLU
|
||||
n_channels = args.n_channels
|
||||
|
||||
if nn_type == 'shallow':
|
||||
|
||||
if args.network1x1 == 'standard':
|
||||
conv1x1 = Conv2dAct(
|
||||
n_channels, n_channels, kernel_size=1, stride=1,
|
||||
padding=0, bias=False)
|
||||
|
||||
layers = [
|
||||
Conv2dAct(c_in, n_channels, kernel, padding=1),
|
||||
conv1x1]
|
||||
|
||||
layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)]
|
||||
|
||||
elif nn_type == 'resnet':
|
||||
layers = [
|
||||
Conv2dAct(c_in, n_channels, kernel, padding=1),
|
||||
ResidualBlock(n_channels, kernel, Conv2dAct),
|
||||
ResidualBlock(n_channels, kernel, Conv2dAct)]
|
||||
|
||||
layers += [
|
||||
torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)
|
||||
]
|
||||
|
||||
elif nn_type == 'densenet':
|
||||
layers = [
|
||||
DenseBlock(
|
||||
args=args,
|
||||
n_inputs=c_in,
|
||||
n_outputs=n_channels + c_in,
|
||||
kernel=kernel,
|
||||
Conv2dAct=Conv2dAct)]
|
||||
|
||||
layers += [
|
||||
torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1)
|
||||
]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
self.nn = torch.nn.Sequential(*layers)
|
||||
|
||||
# Set parameters of last conv-layer to zero.
|
||||
if not UNIT_TESTING:
|
||||
self.nn[-1].weight.data.zero_()
|
||||
self.nn[-1].bias.data.zero_()
|
||||
|
||||
def forward(self, x):
|
||||
return self.nn(x)
|
||||
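A minimal construction sketch for the densenet variant; `_Args` is a hypothetical stub carrying only the fields NN and DenseBlock read, and the import path is assumed from the file location.

import torch
from models.networks import NN     # assumed import path

class _Args:                       # hypothetical stub
    densenet_depth = 4
    n_channels = 64
    network1x1 = 'standard'

net = NN(args=_Args(), c_in=6, c_out=12, height=16, width=16, nn_type='densenet')
out = net(torch.zeros(2, 6, 16, 16))
print(out.shape)                   # torch.Size([2, 12, 16, 16])
print(out.abs().max())             # tensor(0.): the last conv layer is zero-initialised above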
164  integer_discrete_flows/models/priors.py  Normal file
@@ -0,0 +1,164 @@
"""
|
||||
Collection of flow strategies
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
from utils.distributions import sample_discretized_logistic, \
|
||||
sample_mixture_discretized_logistic, sample_normal, sample_logistic, \
|
||||
sample_discretized_normal, sample_mixture_normal
|
||||
from models.utils import Base
|
||||
from .networks import NN
|
||||
|
||||
|
||||
def sample_prior(px, variable_type, distribution_type, inverse_bin_width):
|
||||
if variable_type == 'discrete':
|
||||
if distribution_type == 'logistic':
|
||||
if len(px) == 2:
|
||||
return sample_discretized_logistic(
|
||||
*px, inverse_bin_width=inverse_bin_width)
|
||||
elif len(px) == 3:
|
||||
return sample_mixture_discretized_logistic(
|
||||
*px, inverse_bin_width=inverse_bin_width)
|
||||
|
||||
elif distribution_type == 'normal':
|
||||
return sample_discretized_normal(
|
||||
*px, inverse_bin_width=inverse_bin_width)
|
||||
|
||||
elif variable_type == 'continuous':
|
||||
if distribution_type == 'logistic':
|
||||
return sample_logistic(*px)
|
||||
elif distribution_type == 'normal':
|
||||
if len(px) == 2:
|
||||
return sample_normal(*px)
|
||||
elif len(px) == 3:
|
||||
return sample_mixture_normal(*px)
|
||||
elif distribution_type == 'steplogistic':
|
||||
return sample_logistic(*px)
|
||||
|
||||
raise ValueError
|
||||
|
||||
|
||||
class Prior(Base):
|
||||
def __init__(self, size, args):
|
||||
super().__init__()
|
||||
c, h, w = size
|
||||
|
||||
self.inverse_bin_width = 2**args.n_bits
|
||||
self.variable_type = args.variable_type
|
||||
self.distribution_type = args.distribution_type
|
||||
self.n_mixtures = args.n_mixtures
|
||||
|
||||
if self.n_mixtures == 1:
|
||||
self.mu = Parameter(torch.Tensor(c, h, w))
|
||||
self.logs = Parameter(torch.Tensor(c, h, w))
|
||||
elif self.n_mixtures > 1:
|
||||
self.mu = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
||||
self.logs = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
||||
self.pi_logit = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
self.mu.data.zero_()
|
||||
|
||||
if self.n_mixtures > 1:
|
||||
self.pi_logit.data.zero_()
|
||||
for i in range(self.n_mixtures):
|
||||
self.mu.data[..., i] += i - (self.n_mixtures - 1) / 2.
|
||||
|
||||
self.logs.data.zero_()
|
||||
|
||||
def get_pz(self, n):
|
||||
if self.n_mixtures == 1:
|
||||
mu = self.mu.repeat(n, 1, 1, 1)
|
||||
logs = self.logs.repeat(n, 1, 1, 1) # scaling scale
|
||||
return mu, logs
|
||||
|
||||
elif self.n_mixtures > 1:
|
||||
pi = F.softmax(self.pi_logit, dim=-1)
|
||||
mu = self.mu.repeat(n, 1, 1, 1, 1)
|
||||
logs = self.logs.repeat(n, 1, 1, 1, 1)
|
||||
pi = pi.repeat(n, 1, 1, 1, 1)
|
||||
return mu, logs, pi
|
||||
|
||||
def forward(self, z, ldj):
|
||||
pz = self.get_pz(z.size(0))
|
||||
|
||||
return pz, z, ldj
|
||||
|
||||
def sample(self, n):
|
||||
pz = self.get_pz(n)
|
||||
|
||||
z_sample = sample_prior(pz, self.variable_type, self.distribution_type, self.inverse_bin_width)
|
||||
|
||||
return z_sample
|
||||
|
||||
def decode(self, states, decode_fn):
|
||||
pz = self.get_pz(n=len(states))
|
||||
|
||||
states, z = decode_fn(states, pz)
|
||||
return states, z
|
||||
|
||||
|
||||
class SplitPrior(Base):
|
||||
def __init__(self, c_in, factor_out, height, width, args):
|
||||
super().__init__()
|
||||
|
||||
self.split_idx = c_in - factor_out
|
||||
self.inverse_bin_width = 2**args.n_bits
|
||||
self.variable_type = args.variable_type
|
||||
self.distribution_type = args.distribution_type
|
||||
self.input_channel = c_in
|
||||
|
||||
self.nn = NN(
|
||||
args=args,
|
||||
c_in=c_in - factor_out,
|
||||
c_out=factor_out * 2,
|
||||
height=height,
|
||||
width=width,
|
||||
nn_type=args.splitprior_type)
|
||||
|
||||
def get_py(self, z):
|
||||
h = self.nn(z)
|
||||
mu = h[:, ::2, :, :]
|
||||
logs = h[:, 1::2, :, :]
|
||||
|
||||
py = [mu, logs]
|
||||
|
||||
return py
|
||||
|
||||
def split(self, z):
|
||||
z1 = z[:, :self.split_idx, :, :]
|
||||
y = z[:, self.split_idx:, :, :]
|
||||
return z1, y
|
||||
|
||||
def combine(self, z, y):
|
||||
result = torch.cat([z, y], dim=1)
|
||||
|
||||
return result
|
||||
|
||||
def forward(self, z, ldj):
|
||||
z, y = self.split(z)
|
||||
|
||||
py = self.get_py(z)
|
||||
|
||||
return py, y, z, ldj
|
||||
|
||||
def inverse(self, z, ldj, y):
|
||||
# Sample if y is not given.
|
||||
if y is None:
|
||||
py = self.get_py(z)
|
||||
y = sample_prior(py, self.variable_type, self.distribution_type, self.inverse_bin_width)
|
||||
|
||||
z = self.combine(z, y)
|
||||
|
||||
return z, ldj
|
||||
|
||||
def decode(self, z, ldj, states, decode_fn):
|
||||
py = self.get_py(z)
|
||||
states, y = decode_fn(states, py)
|
||||
return self.combine(z, y), ldj, states
|
||||
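A minimal sketch of the top-level prior; `_Args` is a hypothetical stub carrying only the fields read above, and the sampling call assumes utils.distributions from the same repository is importable.

import torch
from models.priors import Prior        # assumed import path

class _Args:                           # hypothetical stub
    n_bits = 8
    variable_type = 'discrete'
    distribution_type = 'logistic'
    n_mixtures = 1

prior = Prior(size=(8, 4, 4), args=_Args())
mu, logs = prior.get_pz(n=2)
print(mu.shape, logs.shape)            # torch.Size([2, 8, 4, 4]) for both
z = prior.sample(n=2)                  # draws from the discretized logistic at 1/256 bin width
print(z.shape)                         # expected to match (2, 8, 4, 4)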
36  integer_discrete_flows/models/utils.py  Normal file
@@ -0,0 +1,36 @@
import torch


class Base(torch.nn.Module):
    """
    The base class for modules. It adds helpers that propagate rounding-related
    attributes (temperature, hard_round) to all child modules.
    """

    def __init__(self):
        super().__init__()

    def _set_child_attribute(self, attr, value):
        r"""Sets the attribute `attr` to `value` on this module and on every
        child module that defines it.

        This only has an effect on certain modules, e.g. the rounding modules
        used when the variable type is discrete.

        Returns:
            Module: self
        """
        if hasattr(self, attr):
            setattr(self, attr, value)

        for module in self.modules():
            if hasattr(module, attr):
                setattr(module, attr, value)
        return self

    def set_temperature(self, value):
        self._set_child_attribute("temperature", value)

    def enable_hard_round(self, mode=True):
        self._set_child_attribute("hard_round", mode)

    def disable_hard_round(self, mode=True):
        self.enable_hard_round(not mode)
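A small sketch of the attribute propagation that set_temperature and enable_hard_round rely on (toy modules, import path assumed):

import torch
from models.utils import Base          # assumed import path

class _Child(Base):
    def __init__(self):
        super().__init__()
        self.temperature = None        # picked up by set_temperature via hasattr

class _Parent(Base):
    def __init__(self):
        super().__init__()
        self.child = _Child()

m = _Parent()
m.set_temperature(0.5)
print(m.child.temperature)             # 0.5: _set_child_attribute walks self.modules()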