feat: initial for IDF

This commit is contained in:
Robin Meersman 2025-11-07 12:54:36 +01:00
commit ef4684ef39
27 changed files with 2830 additions and 0 deletions

View file

@ -0,0 +1,191 @@
import torch
import models.generative_flows as generative_flows
import numpy as np
from models.utils import Base
from .priors import Prior
from optimization.loss import compute_loss_array
from coding.coder import encode_sample, decode_sample
class Normalize(Base):
def __init__(self, args):
super().__init__()
self.n_bits = args.n_bits
self.variable_type = args.variable_type
self.input_size = args.input_size
def forward(self, x, ldj, reverse=False):
domain = 2.**self.n_bits
if self.variable_type == 'discrete':
# Discrete variables will be measured on intervals sized 1/domain.
# Hence, there is no need to change the log Jacobian determinant.
dldj = 0
elif self.variable_type == 'continuous':
dldj = -np.log(domain) * np.prod(self.input_size)
else:
raise ValueError
if not reverse:
x = (x - domain / 2) / domain
ldj += dldj
else:
x = x * domain + domain / 2
ldj -= dldj
return x, ldj
class Model(Base):
"""
The base VAE class containing gated convolutional encoder and decoder
architecture. Can be used as a base class for VAE's with normalizing flows.
"""
def __init__(self, args):
super().__init__()
self.args = args
self.variable_type = args.variable_type
self.distribution_type = args.distribution_type
n_channels, height, width = args.input_size
self.normalize = Normalize(args)
self.flow = generative_flows.GenerativeFlow(
n_channels, height, width, args)
self.n_bits = args.n_bits
self.z_size = self.flow.z_size
self.prior = Prior(self.z_size, args)
def dequantize(self, x):
if self.training:
x = x + torch.rand_like(x)
else:
# Required for stability.
alpha = 1e-3
x = x + alpha + torch.rand_like(x) * (1 - 2 * alpha)
return x
def loss(self, pz, z, pys, ys, ldj):
batchsize = z.size(0)
loss, bpd, bpd_per_prior = \
compute_loss_array(pz, z, pys, ys, ldj, self.args)
for module in self.modules():
if hasattr(module, 'auxillary_loss'):
loss += module.auxillary_loss() / batchsize
return loss, bpd, bpd_per_prior
def forward(self, x):
"""
Evaluates the model as a whole, encodes and decodes. Note that the log
det jacobian is zero for a plain VAE (without flows), and z_0 = z_k.
"""
# Decode z to x.
assert x.dtype == torch.uint8
x = x.float()
ldj = torch.zeros_like(x[:, 0, 0, 0])
if self.variable_type == 'continuous':
x = self.dequantize(x)
elif self.variable_type == 'discrete':
pass
else:
raise ValueError
x, ldj = self.normalize(x, ldj)
z, ldj, pys, ys = self.flow(x, ldj, pys=(), ys=())
pz, z, ldj = self.prior(z, ldj)
loss, bpd, bpd_per_prior = self.loss(pz, z, pys, ys, ldj)
return loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj
def inverse(self, z, ys):
ldj = torch.zeros_like(z[:, 0, 0, 0])
x, ldj, pys, py = \
self.flow(z, ldj, pys=[], ys=ys, reverse=True)
x, ldj = self.normalize(x, ldj, reverse=True)
x_uint8 = torch.clamp(x, min=0, max=255).to(
torch.uint8)
return x_uint8
def sample(self, n):
z_sample = self.prior.sample(n)
ldj = torch.zeros_like(z_sample[:, 0, 0, 0])
x_sample, ldj, pys, py = \
self.flow(z_sample, ldj, pys=[], ys=[], reverse=True)
x_sample, ldj = self.normalize(x_sample, ldj, reverse=True)
x_sample_uint8 = torch.clamp(x_sample, min=0, max=255).to(
torch.uint8)
return x_sample_uint8
def encode(self, x):
batchsize = x.size(0)
_, _, _, pz, z, pys, ys, _ = self.forward(x)
pjs = list(pys) + [pz]
js = list(ys) + [z]
states = []
for b in range(batchsize):
state = None
for pj, j in zip(pjs, js):
pj_b = [param[b:b+1] for param in pj]
j_b = j[b:b+1]
state = encode_sample(
j_b, pj_b, self.variable_type,
self.distribution_type, state=state)
if state is None:
break
states.append(state)
return states
def decode(self, states):
def decode_fn(states, pj):
states = list(states)
j = []
for b in range(len(states)):
pj_b = [param[b:b+1] for param in pj]
states[b], j_b = decode_sample(
states[b], pj_b, self.variable_type,
self.distribution_type)
j.append(j_b)
j = torch.cat(j, dim=0)
return states, j
states, z = self.prior.decode(states, decode_fn=decode_fn)
ldj = torch.zeros_like(z[:, 0, 0, 0])
x, ldj = self.flow.decode(z, ldj, states, decode_fn=decode_fn)
x, ldj = self.normalize(x, ldj, reverse=True)
x = x.to(dtype=torch.uint8)
return x

View file

@ -0,0 +1,151 @@
import torch
import torch.nn.functional as F
import numpy as np
from models.utils import Base
class RoundStraightThrough(torch.autograd.Function):
def __init__(self):
super().__init__()
@staticmethod
def forward(ctx, input):
rounded = torch.round(input, out=None)
return rounded
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
return grad_input
_round_straightthrough = RoundStraightThrough().apply
def _stacked_sigmoid(x, temperature, n_approx=3):
x_ = x - 0.5
rounded = torch.round(x_)
x_remainder = x_ - rounded
size = x_.size()
x_remainder = x_remainder.view(size + (1,))
translation = torch.arange(n_approx) - n_approx // 2
translation = translation.to(device=x.device, dtype=x.dtype)
translation = translation.view([1] * len(size) + [len(translation)])
out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)
return out + rounded - (n_approx // 2)
class SmoothRound(Base):
def __init__(self):
self._temperature = None
self._n_approx = None
super().__init__()
self.hard_round = None
@property
def temperature(self):
return self._temperature
@temperature.setter
def temperature(self, value):
self._temperature = value
if self._temperature <= 0.05:
self._n_approx = 1
elif 0.05 < self._temperature < 0.13:
self._n_approx = 3
else:
self._n_approx = 5
def forward(self, x):
assert self._temperature is not None
assert self._n_approx is not None
assert self.hard_round is not None
if self.temperature <= 0.25:
h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx)
else:
h = x
if self.hard_round:
h = _round_straightthrough(h)
return h
class StochasticRound(Base):
def __init__(self):
super().__init__()
self.hard_round = None
def forward(self, x):
u = torch.rand_like(x)
h = x + u - 0.5
if self.hard_round:
h = _round_straightthrough(h)
return h
class BackRound(Base):
def __init__(self, args, inverse_bin_width):
"""
BackRound is an approximation to Round that allows for Backpropagation.
Approximate the round function using a sum of translated sigmoids.
The temperature determines how well the round function is approximated,
i.e., a lower temperature corresponds to a better approximation, at
the cost of more vanishing gradients.
BackRound supports the following settings:
* By setting hard to True and temperature > 0.25, BackRound
reduces to a round function with a straight through gradient
estimator
* When using 0 < temperature <= 0.25 and hard = True, the
output in the forward pass is equivalent to a round function, but the
gradient is approximated by the gradient of a sum of sigmoids.
* When using hard = False, the output is not constrained to integers.
* When temperature > 0.25 and hard = False, BackRound reduces to
the identity function.
Arguments
---------
temperature: float
Temperature used for stacked sigmoid approximated. If temperature
is greater than 0.25, the approximation reduces to the indentiy
function.
hard: bool
If hard is True, a (hard) round is applied before returning. The
gradient for this is approximated using the straight-through
estimator.
"""
super().__init__()
self.inverse_bin_width = inverse_bin_width
self.round_approx = args.round_approx
if args.round_approx == 'smooth':
self.round = SmoothRound()
elif args.round_approx == 'stochastic':
self.round = StochasticRound()
else:
raise ValueError
def forward(self, x):
if self.round_approx == 'smooth' or self.round_approx == 'stochastic':
h = x * self.inverse_bin_width
h = self.round(h)
return h / self.inverse_bin_width
else:
raise ValueError

View file

@ -0,0 +1,142 @@
"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import numpy as np
from models.utils import Base
from .backround import BackRound
from .networks import NN
UNIT_TESTING = False
class SplitFactorCoupling(Base):
def __init__(self, c_in, factor, height, width, args):
super().__init__()
self.n_channels = args.n_channels
self.kernel = 3
self.input_channel = c_in
self.round_approx = args.round_approx
if args.variable_type == 'discrete':
self.round = BackRound(
args, inverse_bin_width=2**args.n_bits)
else:
self.round = None
self.split_idx = c_in - (c_in // factor)
self.nn = NN(
args=args,
c_in=self.split_idx,
c_out=c_in - self.split_idx,
height=height,
width=width,
kernel=self.kernel,
nn_type=args.coupling_type)
def forward(self, z, ldj, reverse=False):
z1 = z[:, :self.split_idx, :, :]
z2 = z[:, self.split_idx:, :, :]
t = self.nn(z1)
if self.round is not None:
t = self.round(t)
if not reverse:
z2 = z2 + t
else:
z2 = z2 - t
z = torch.cat([z1, z2], dim=1)
return z, ldj
class Coupling(Base):
def __init__(self, c_in, height, width, args):
super().__init__()
if args.split_quarter:
factor = 4
elif args.splitfactor > 1:
factor = args.splitfactor
else:
factor = 2
self.coupling = SplitFactorCoupling(
c_in, factor, height, width, args=args)
def forward(self, z, ldj, reverse=False):
return self.coupling(z, ldj, reverse)
def test_generative_flow():
import models.networks as networks
global UNIT_TESTING
networks.UNIT_TESTING = True
UNIT_TESTING = True
batch_size = 17
input_size = [12, 16, 16]
class Args():
def __init__(self):
self.input_size = input_size
self.learn_split = False
self.variable_type = 'continuous'
self.distribution_type = 'logistic'
self.round_approx = 'smooth'
self.coupling_type = 'shallow'
self.conv_type = 'standard'
self.densenet_depth = 8
self.bottleneck = False
self.n_channels = 512
self.network1x1 = 'standard'
self.auxilary_freq = -1
self.actnorm = False
self.LU = False
self.coupling_lifting_L = True
self.splitprior = True
self.split_quarter = True
self.n_levels = 2
self.n_flows = 2
self.cond_L = True
self.n_bits = True
args = Args()
x = (torch.randint(256, size=[batch_size] + input_size).float() - 128.) / 256.
ldj = torch.zeros_like(x[:, 0, 0, 0])
model = Coupling(c_in=12, height=16, width=16, args=args)
print(model)
model.set_temperature(1.)
model.enable_hard_round()
model.eval()
z, ldj = model(x, ldj, reverse=False)
# Check if gradient computation works
loss = torch.sum(z**2)
loss.backward()
recon, ldj = model(z, ldj, reverse=True)
sse = torch.sum(torch.pow(x - recon, 2)).item()
ae = torch.abs(x - recon).sum()
print('Error in recon: sse {} ae {}'.format(sse / np.prod(input_size), ae))
if __name__ == '__main__':
test_generative_flow()

View file

@ -0,0 +1,175 @@
"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import numpy as np
from models.utils import Base
from .priors import SplitPrior
from .coupling import Coupling
UNIT_TESTING = False
def space_to_depth(x):
xs = x.size()
# Pick off every second element
x = x.view(xs[0], xs[1], xs[2] // 2, 2, xs[3] // 2, 2)
# Transpose picked elements next to channels.
x = x.permute((0, 1, 3, 5, 2, 4)).contiguous()
# Combine with channels.
x = x.view(xs[0], xs[1] * 4, xs[2] // 2, xs[3] // 2)
return x
def depth_to_space(x):
xs = x.size()
# Pick off elements from channels
x = x.view(xs[0], xs[1] // 4, 2, 2, xs[2], xs[3])
# Transpose picked elements next to HW dimensions.
x = x.permute((0, 1, 4, 2, 5, 3)).contiguous()
# Combine with HW dimensions.
x = x.view(xs[0], xs[1] // 4, xs[2] * 2, xs[3] * 2)
return x
def int_shape(x):
return list(map(int, x.size()))
class Flatten(Base):
def forward(self, x):
return x.view(x.size(0), -1)
class Reshape(Base):
def __init__(self, shape):
super().__init__()
self.shape = shape
def forward(self, x):
return x.view(x.size(0), *self.shape)
class Reverse(Base):
def __init__(self):
super().__init__()
def forward(self, z, reverse=False):
flip_idx = torch.arange(z.size(1) - 1, -1, -1).long()
z = z[:, flip_idx, :, :]
return z
class Permute(Base):
def __init__(self, n_channels):
super().__init__()
permutation = np.arange(n_channels, dtype='int')
np.random.shuffle(permutation)
permutation_inv = np.zeros(n_channels, dtype='int')
permutation_inv[permutation] = np.arange(n_channels, dtype='int')
self.permutation = torch.from_numpy(permutation)
self.permutation_inv = torch.from_numpy(permutation_inv)
def forward(self, z, ldj, reverse=False):
if not reverse:
z = z[:, self.permutation, :, :]
else:
z = z[:, self.permutation_inv, :, :]
return z, ldj
def InversePermute(self):
inv_permute = Permute(len(self.permutation))
inv_permute.permutation = self.permutation_inv
inv_permute.permutation_inv = self.permutation
return inv_permute
class Squeeze(Base):
def __init__(self):
super().__init__()
def forward(self, z, ldj, reverse=False):
if not reverse:
z = space_to_depth(z)
else:
z = depth_to_space(z)
return z, ldj
class GenerativeFlow(Base):
def __init__(self, n_channels, height, width, args):
super().__init__()
layers = []
layers.append(Squeeze())
n_channels *= 4
height //= 2
width //= 2
for level in range(args.n_levels):
for i in range(args.n_flows):
perm_layer = Permute(n_channels)
layers.append(perm_layer)
layers.append(
Coupling(n_channels, height, width, args))
if level < args.n_levels - 1:
if args.splitprior_type != 'none':
# Standard splitprior
factor_out = n_channels // 2
layers.append(SplitPrior(n_channels, factor_out, height, width, args))
n_channels = n_channels - factor_out
layers.append(Squeeze())
n_channels *= 4
height //= 2
width //= 2
self.layers = torch.nn.ModuleList(layers)
self.z_size = (n_channels, height, width)
def forward(self, z, ldj, pys=(), ys=(), reverse=False):
if not reverse:
for l, layer in enumerate(self.layers):
if isinstance(layer, (SplitPrior)):
py, y, z, ldj = layer(z, ldj)
pys += (py,)
ys += (y,)
else:
z, ldj = layer(z, ldj)
else:
for l, layer in reversed(list(enumerate(self.layers))):
if isinstance(layer, (SplitPrior)):
if len(ys) > 0:
z, ldj = layer.inverse(z, ldj, y=ys[-1])
# Pop last element
ys = ys[:-1]
else:
z, ldj = layer.inverse(z, ldj, y=None)
else:
z, ldj = layer(z, ldj, reverse=True)
return z, ldj, pys, ys
def decode(self, z, ldj, state, decode_fn):
for l, layer in reversed(list(enumerate(self.layers))):
if isinstance(layer, SplitPrior):
z, ldj, state = layer.decode(z, ldj, state, decode_fn)
else:
z, ldj = layer(z, ldj, reverse=True)
return z, ldj

View file

@ -0,0 +1,154 @@
"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.utils import Base
UNIT_TESTING = False
class Conv2dReLU(Base):
def __init__(
self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0,
bias=True):
super().__init__()
self.nn = nn.Conv2d(n_inputs, n_outputs, kernel_size, padding=padding)
def forward(self, x):
h = self.nn(x)
y = F.relu(h)
return y
class ResidualBlock(Base):
def __init__(self, n_channels, kernel, Conv2dAct):
super().__init__()
self.nn = torch.nn.Sequential(
Conv2dAct(n_channels, n_channels, kernel, padding=1),
torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1),
)
def forward(self, x):
h = self.nn(x)
h = F.relu(h + x)
return h
class DenseLayer(Base):
def __init__(self, args, n_inputs, growth, Conv2dAct):
super().__init__()
conv1x1 = Conv2dAct(
n_inputs, n_inputs, kernel_size=1, stride=1,
padding=0, bias=True)
self.nn = torch.nn.Sequential(
conv1x1,
Conv2dAct(
n_inputs, growth, kernel_size=3, stride=1,
padding=1, bias=True),
)
def forward(self, x):
h = self.nn(x)
h = torch.cat([x, h], dim=1)
return h
class DenseBlock(Base):
def __init__(
self, args, n_inputs, n_outputs, kernel, Conv2dAct):
super().__init__()
depth = args.densenet_depth
future_growth = n_outputs - n_inputs
layers = []
for d in range(depth):
growth = future_growth // (depth - d)
layers.append(DenseLayer(args, n_inputs, growth, Conv2dAct))
n_inputs += growth
future_growth -= growth
self.nn = torch.nn.Sequential(*layers)
def forward(self, x):
return self.nn(x)
class Identity(Base):
def __init__(self):
super.__init__()
def forward(self, x):
return x
class NN(Base):
def __init__(
self, args, c_in, c_out, height, width, nn_type, kernel=3):
super().__init__()
Conv2dAct = Conv2dReLU
n_channels = args.n_channels
if nn_type == 'shallow':
if args.network1x1 == 'standard':
conv1x1 = Conv2dAct(
n_channels, n_channels, kernel_size=1, stride=1,
padding=0, bias=False)
layers = [
Conv2dAct(c_in, n_channels, kernel, padding=1),
conv1x1]
layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)]
elif nn_type == 'resnet':
layers = [
Conv2dAct(c_in, n_channels, kernel, padding=1),
ResidualBlock(n_channels, kernel, Conv2dAct),
ResidualBlock(n_channels, kernel, Conv2dAct)]
layers += [
torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)
]
elif nn_type == 'densenet':
layers = [
DenseBlock(
args=args,
n_inputs=c_in,
n_outputs=n_channels + c_in,
kernel=kernel,
Conv2dAct=Conv2dAct)]
layers += [
torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1)
]
else:
raise ValueError
self.nn = torch.nn.Sequential(*layers)
# Set parameters of last conv-layer to zero.
if not UNIT_TESTING:
self.nn[-1].weight.data.zero_()
self.nn[-1].bias.data.zero_()
def forward(self, x):
return self.nn(x)

View file

@ -0,0 +1,164 @@
"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from utils.distributions import sample_discretized_logistic, \
sample_mixture_discretized_logistic, sample_normal, sample_logistic, \
sample_discretized_normal, sample_mixture_normal
from models.utils import Base
from .networks import NN
def sample_prior(px, variable_type, distribution_type, inverse_bin_width):
if variable_type == 'discrete':
if distribution_type == 'logistic':
if len(px) == 2:
return sample_discretized_logistic(
*px, inverse_bin_width=inverse_bin_width)
elif len(px) == 3:
return sample_mixture_discretized_logistic(
*px, inverse_bin_width=inverse_bin_width)
elif distribution_type == 'normal':
return sample_discretized_normal(
*px, inverse_bin_width=inverse_bin_width)
elif variable_type == 'continuous':
if distribution_type == 'logistic':
return sample_logistic(*px)
elif distribution_type == 'normal':
if len(px) == 2:
return sample_normal(*px)
elif len(px) == 3:
return sample_mixture_normal(*px)
elif distribution_type == 'steplogistic':
return sample_logistic(*px)
raise ValueError
class Prior(Base):
def __init__(self, size, args):
super().__init__()
c, h, w = size
self.inverse_bin_width = 2**args.n_bits
self.variable_type = args.variable_type
self.distribution_type = args.distribution_type
self.n_mixtures = args.n_mixtures
if self.n_mixtures == 1:
self.mu = Parameter(torch.Tensor(c, h, w))
self.logs = Parameter(torch.Tensor(c, h, w))
elif self.n_mixtures > 1:
self.mu = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
self.logs = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
self.pi_logit = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
self.reset_parameters()
def reset_parameters(self):
self.mu.data.zero_()
if self.n_mixtures > 1:
self.pi_logit.data.zero_()
for i in range(self.n_mixtures):
self.mu.data[..., i] += i - (self.n_mixtures - 1) / 2.
self.logs.data.zero_()
def get_pz(self, n):
if self.n_mixtures == 1:
mu = self.mu.repeat(n, 1, 1, 1)
logs = self.logs.repeat(n, 1, 1, 1) # scaling scale
return mu, logs
elif self.n_mixtures > 1:
pi = F.softmax(self.pi_logit, dim=-1)
mu = self.mu.repeat(n, 1, 1, 1, 1)
logs = self.logs.repeat(n, 1, 1, 1, 1)
pi = pi.repeat(n, 1, 1, 1, 1)
return mu, logs, pi
def forward(self, z, ldj):
pz = self.get_pz(z.size(0))
return pz, z, ldj
def sample(self, n):
pz = self.get_pz(n)
z_sample = sample_prior(pz, self.variable_type, self.distribution_type, self.inverse_bin_width)
return z_sample
def decode(self, states, decode_fn):
pz = self.get_pz(n=len(states))
states, z = decode_fn(states, pz)
return states, z
class SplitPrior(Base):
def __init__(self, c_in, factor_out, height, width, args):
super().__init__()
self.split_idx = c_in - factor_out
self.inverse_bin_width = 2**args.n_bits
self.variable_type = args.variable_type
self.distribution_type = args.distribution_type
self.input_channel = c_in
self.nn = NN(
args=args,
c_in=c_in - factor_out,
c_out=factor_out * 2,
height=height,
width=width,
nn_type=args.splitprior_type)
def get_py(self, z):
h = self.nn(z)
mu = h[:, ::2, :, :]
logs = h[:, 1::2, :, :]
py = [mu, logs]
return py
def split(self, z):
z1 = z[:, :self.split_idx, :, :]
y = z[:, self.split_idx:, :, :]
return z1, y
def combine(self, z, y):
result = torch.cat([z, y], dim=1)
return result
def forward(self, z, ldj):
z, y = self.split(z)
py = self.get_py(z)
return py, y, z, ldj
def inverse(self, z, ldj, y):
# Sample if y is not given.
if y is None:
py = self.get_py(z)
y = sample_prior(py, self.variable_type, self.distribution_type, self.inverse_bin_width)
z = self.combine(z, y)
return z, ldj
def decode(self, z, ldj, states, decode_fn):
py = self.get_py(z)
states, y = decode_fn(states, py)
return self.combine(z, y), ldj, states

View file

@ -0,0 +1,36 @@
import torch
class Base(torch.nn.Module):
"""
The base class for modules. That contains a disable round mode
"""
def __init__(self):
super().__init__()
def _set_child_attribute(self, attr, value):
r"""Sets the module in rounding mode.
This has any effect only on certain modules if variable type is
discrete.
Returns:
Module: self
"""
if hasattr(self, attr):
setattr(self, attr, value)
for module in self.modules():
if hasattr(module, attr):
setattr(module, attr, value)
return self
def set_temperature(self, value):
self._set_child_attribute("temperature", value)
def enable_hard_round(self, mode=True):
self._set_child_attribute("hard_round", mode)
def disable_hard_round(self, mode=True):
self.enable_hard_round(not mode)