feat: initial for IDF
This commit is contained in:
commit
ef4684ef39
27 changed files with 2830 additions and 0 deletions
175
integer_discrete_flows/models/generative_flows.py
Normal file
175
integer_discrete_flows/models/generative_flows.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
"""
|
||||
Collection of flow strategies
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from models.utils import Base
|
||||
from .priors import SplitPrior
|
||||
from .coupling import Coupling
|
||||
|
||||
|
||||
UNIT_TESTING = False
|
||||
|
||||
|
||||
def space_to_depth(x):
|
||||
xs = x.size()
|
||||
# Pick off every second element
|
||||
x = x.view(xs[0], xs[1], xs[2] // 2, 2, xs[3] // 2, 2)
|
||||
# Transpose picked elements next to channels.
|
||||
x = x.permute((0, 1, 3, 5, 2, 4)).contiguous()
|
||||
# Combine with channels.
|
||||
x = x.view(xs[0], xs[1] * 4, xs[2] // 2, xs[3] // 2)
|
||||
return x
|
||||
|
||||
|
||||
def depth_to_space(x):
|
||||
xs = x.size()
|
||||
# Pick off elements from channels
|
||||
x = x.view(xs[0], xs[1] // 4, 2, 2, xs[2], xs[3])
|
||||
# Transpose picked elements next to HW dimensions.
|
||||
x = x.permute((0, 1, 4, 2, 5, 3)).contiguous()
|
||||
# Combine with HW dimensions.
|
||||
x = x.view(xs[0], xs[1] // 4, xs[2] * 2, xs[3] * 2)
|
||||
return x
|
||||
|
||||
|
||||
def int_shape(x):
|
||||
return list(map(int, x.size()))
|
||||
|
||||
|
||||
class Flatten(Base):
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
class Reshape(Base):
|
||||
def __init__(self, shape):
|
||||
super().__init__()
|
||||
self.shape = shape
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), *self.shape)
|
||||
|
||||
|
||||
class Reverse(Base):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, z, reverse=False):
|
||||
flip_idx = torch.arange(z.size(1) - 1, -1, -1).long()
|
||||
z = z[:, flip_idx, :, :]
|
||||
return z
|
||||
|
||||
|
||||
class Permute(Base):
|
||||
def __init__(self, n_channels):
|
||||
super().__init__()
|
||||
|
||||
permutation = np.arange(n_channels, dtype='int')
|
||||
np.random.shuffle(permutation)
|
||||
|
||||
permutation_inv = np.zeros(n_channels, dtype='int')
|
||||
permutation_inv[permutation] = np.arange(n_channels, dtype='int')
|
||||
|
||||
self.permutation = torch.from_numpy(permutation)
|
||||
self.permutation_inv = torch.from_numpy(permutation_inv)
|
||||
|
||||
def forward(self, z, ldj, reverse=False):
|
||||
if not reverse:
|
||||
z = z[:, self.permutation, :, :]
|
||||
else:
|
||||
z = z[:, self.permutation_inv, :, :]
|
||||
|
||||
return z, ldj
|
||||
|
||||
def InversePermute(self):
|
||||
inv_permute = Permute(len(self.permutation))
|
||||
inv_permute.permutation = self.permutation_inv
|
||||
inv_permute.permutation_inv = self.permutation
|
||||
return inv_permute
|
||||
|
||||
|
||||
class Squeeze(Base):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, z, ldj, reverse=False):
|
||||
if not reverse:
|
||||
z = space_to_depth(z)
|
||||
else:
|
||||
z = depth_to_space(z)
|
||||
return z, ldj
|
||||
|
||||
|
||||
class GenerativeFlow(Base):
|
||||
def __init__(self, n_channels, height, width, args):
|
||||
super().__init__()
|
||||
layers = []
|
||||
layers.append(Squeeze())
|
||||
n_channels *= 4
|
||||
height //= 2
|
||||
width //= 2
|
||||
|
||||
for level in range(args.n_levels):
|
||||
|
||||
for i in range(args.n_flows):
|
||||
perm_layer = Permute(n_channels)
|
||||
layers.append(perm_layer)
|
||||
|
||||
layers.append(
|
||||
Coupling(n_channels, height, width, args))
|
||||
|
||||
if level < args.n_levels - 1:
|
||||
if args.splitprior_type != 'none':
|
||||
# Standard splitprior
|
||||
factor_out = n_channels // 2
|
||||
layers.append(SplitPrior(n_channels, factor_out, height, width, args))
|
||||
n_channels = n_channels - factor_out
|
||||
|
||||
layers.append(Squeeze())
|
||||
n_channels *= 4
|
||||
height //= 2
|
||||
width //= 2
|
||||
|
||||
self.layers = torch.nn.ModuleList(layers)
|
||||
self.z_size = (n_channels, height, width)
|
||||
|
||||
def forward(self, z, ldj, pys=(), ys=(), reverse=False):
|
||||
if not reverse:
|
||||
for l, layer in enumerate(self.layers):
|
||||
if isinstance(layer, (SplitPrior)):
|
||||
py, y, z, ldj = layer(z, ldj)
|
||||
pys += (py,)
|
||||
ys += (y,)
|
||||
|
||||
else:
|
||||
z, ldj = layer(z, ldj)
|
||||
|
||||
else:
|
||||
for l, layer in reversed(list(enumerate(self.layers))):
|
||||
if isinstance(layer, (SplitPrior)):
|
||||
if len(ys) > 0:
|
||||
z, ldj = layer.inverse(z, ldj, y=ys[-1])
|
||||
# Pop last element
|
||||
ys = ys[:-1]
|
||||
else:
|
||||
z, ldj = layer.inverse(z, ldj, y=None)
|
||||
|
||||
else:
|
||||
z, ldj = layer(z, ldj, reverse=True)
|
||||
|
||||
return z, ldj, pys, ys
|
||||
|
||||
def decode(self, z, ldj, state, decode_fn):
|
||||
|
||||
for l, layer in reversed(list(enumerate(self.layers))):
|
||||
if isinstance(layer, SplitPrior):
|
||||
z, ldj, state = layer.decode(z, ldj, state, decode_fn)
|
||||
|
||||
else:
|
||||
z, ldj = layer(z, ldj, reverse=True)
|
||||
|
||||
return z, ldj
|
||||
Reference in a new issue