175 lines
4.7 KiB
Python
175 lines
4.7 KiB
Python
"""
Collection of flow strategies: space/depth reshuffles, channel-permutation
layers and the GenerativeFlow container that stacks them.
"""
|
|
|
|
from __future__ import print_function
|
|
|
|
import torch
|
|
import numpy as np
|
|
from models.utils import Base
|
|
from .priors import SplitPrior
|
|
from .coupling import Coupling
|
|
|
|
|
|
# Module-level testing flag; nothing in the visible part of this file reads it.
UNIT_TESTING: bool = False
|
|
|
|
|
|
def space_to_depth(x):
    """Fold every 2x2 spatial patch of ``x`` into the channel axis.

    Reshapes a (B, C, H, W) tensor into (B, 4 * C, H // 2, W // 2). In the
    output, channel ``4 * c + 2 * dy + dx`` holds input channel ``c`` at row
    offset ``dy`` and column offset ``dx`` of each patch. ``H`` and ``W`` are
    expected to be even.
    """
    shape = x.size()
    batch, channels, rows, cols = shape[0], shape[1], shape[2], shape[3]
    half_rows, half_cols = rows // 2, cols // 2

    # (B, C, H/2, dy, W/2, dx): expose the offset inside each 2x2 patch.
    patches = x.view(batch, channels, half_rows, 2, half_cols, 2)
    # (B, C, dy, dx, H/2, W/2): bring both offsets next to the channel axis.
    patches = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
    # Merge C with (dy, dx) into a single channel axis of size 4 * C.
    return patches.view(batch, channels * 4, half_rows, half_cols)
|
|
|
|
|
|
def depth_to_space(x):
    """Inverse of ``space_to_depth``: unfold channel groups into 2x2 patches.

    Reshapes a (B, C, H, W) tensor into (B, C // 4, 2 * H, 2 * W). Input
    channel ``4 * c + 2 * dy + dx`` becomes row offset ``dy`` and column
    offset ``dx`` of output channel ``c``. ``C`` is expected to be a
    multiple of 4.
    """
    shape = x.size()
    batch, channels, rows, cols = shape[0], shape[1], shape[2], shape[3]
    groups = channels // 4

    # (B, C/4, dy, dx, H, W): split the channel axis into (group, dy, dx).
    patches = x.view(batch, groups, 2, 2, rows, cols)
    # (B, C/4, H, dy, W, dx): interleave each offset with its spatial axis.
    patches = patches.permute(0, 1, 4, 2, 5, 3).contiguous()
    # Merge (H, dy) and (W, dx) into the doubled spatial axes.
    return patches.view(batch, groups, rows * 2, cols * 2)
|
|
|
|
|
|
def int_shape(x):
    """Return the size of ``x`` as a plain list of Python ints."""
    return [int(dim) for dim in x.size()]
|
|
|
|
|
|
class Flatten(Base):
    """Flatten every dimension after the batch dimension: (B, ...) -> (B, N)."""

    def forward(self, x):
        # reshape (rather than view) also accepts non-contiguous inputs, e.g.
        # the output of permute/transpose, where view raises. It returns a
        # view whenever one is possible, so contiguous inputs behave as before.
        return x.reshape(x.size(0), -1)
|
|
|
|
|
|
class Reshape(Base):
    """Reshape each sample to ``shape`` while keeping the batch dimension.

    Args:
        shape: per-sample target shape; unpacked with ``*`` after the batch
            size, so any iterable of ints accepted by ``Tensor.reshape``.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        # reshape (rather than view) also works for non-contiguous inputs;
        # for contiguous inputs it still returns a view, as before.
        return x.reshape(x.size(0), *self.shape)
|
|
|
|
|
|
class Reverse(Base):
    """Reverse the order of the channel axis (dim 1).

    Reversing is its own inverse, so the ``reverse`` flag has no effect.

    NOTE(review): unlike Permute/Squeeze, ``forward`` takes no ``ldj`` and
    returns only ``z``, so this layer cannot be placed in
    ``GenerativeFlow.layers`` as-is (that loop calls ``layer(z, ldj)`` and
    unpacks two values).
    """

    def __init__(self):
        super().__init__()

    def forward(self, z, reverse=False):
        # torch.flip avoids building a fresh CPU index tensor on every call,
        # which also forces a host->device copy when z lives on a GPU.
        return torch.flip(z, dims=(1,))
|
|
|
|
|
|
class Permute(Base):
    """Fixed random permutation of the channel axis (dim 1).

    The permutation is drawn once, at construction, from numpy's global RNG.
    Elements are only moved, so ``ldj`` is returned unchanged.

    Args:
        n_channels: size of dim 1 of the tensors this layer is applied to.
    """

    def __init__(self, n_channels):
        super().__init__()

        permutation = np.arange(n_channels, dtype=np.int64)
        np.random.shuffle(permutation)

        # Inverse permutation: permutation_inv[permutation[i]] == i.
        permutation_inv = np.zeros(n_channels, dtype=np.int64)
        permutation_inv[permutation] = np.arange(n_channels, dtype=np.int64)

        # Use int64 explicitly: numpy's 'int' is int32 on Windows, which
        # yields non-Long index tensors. The tensors are registered as
        # non-persistent buffers so they move with .to()/.cuda() alongside
        # the module, while the state_dict keys stay exactly as before.
        # NOTE(review): because they are not persisted, a reloaded model only
        # reproduces these permutations if numpy was seeded identically.
        self.register_buffer(
            'permutation', torch.from_numpy(permutation), persistent=False)
        self.register_buffer(
            'permutation_inv', torch.from_numpy(permutation_inv),
            persistent=False)

    def forward(self, z, ldj, reverse=False):
        """Permute channels of ``z`` (or undo it when ``reverse``); pass ``ldj`` through."""
        index = self.permutation_inv if reverse else self.permutation
        return z[:, index, :, :], ldj

    def InversePermute(self):
        """Return a new Permute layer that undoes this one.

        A fresh Permute is constructed (which consumes one numpy shuffle) and
        its index tensors are then replaced by this layer's, swapped.
        """
        inv_permute = Permute(len(self.permutation))
        inv_permute.permutation = self.permutation_inv
        inv_permute.permutation_inv = self.permutation
        return inv_permute
|
|
|
|
|
|
class Squeeze(Base):
    """Invertible 2x2 space-to-depth reshuffle.

    Forward applies ``space_to_depth`` ((B, C, H, W) -> (B, 4C, H/2, W/2));
    reverse applies ``depth_to_space``. Elements are only moved, so ``ldj``
    is passed through unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z, ldj, reverse=False):
        transform = depth_to_space if reverse else space_to_depth
        return transform(z), ldj
|
|
|
|
|
|
class GenerativeFlow(Base):
    """Stack of invertible layers arranged over several resolution levels.

    Layout, in forward order: one initial Squeeze, then for each of
    ``args.n_levels`` levels, ``args.n_flows`` (Permute, Coupling) pairs.
    Between levels (not after the last one) an optional SplitPrior receives
    half of the channels, and the remainder goes through another Squeeze.

    Args:
        n_channels, height, width: expected (channels, height, width) of the
            input; after the initial Squeeze the first level works on 4x the
            channels at half the resolution.
        args: config object. This class reads ``n_levels``, ``n_flows`` and
            ``splitprior_type`` and forwards ``args`` to Coupling/SplitPrior.

    Attributes:
        layers: ModuleList of the layers in forward order.
        z_size: (channels, height, width) left after the final level.
    """

    def __init__(self, n_channels, height, width, args):
        super().__init__()

        # Layers are created in exactly this order; Permute draws from numpy's
        # global RNG when constructed, so reordering would change the model.
        layers = [Squeeze()]
        n_channels *= 4
        height //= 2
        width //= 2

        for level in range(args.n_levels):
            for _ in range(args.n_flows):
                layers.append(Permute(n_channels))
                layers.append(Coupling(n_channels, height, width, args))

            if level < args.n_levels - 1:
                if args.splitprior_type != 'none':
                    # Hand half of the channels to the SplitPrior; only the
                    # remaining channels continue through the flow.
                    factor_out = n_channels // 2
                    layers.append(
                        SplitPrior(n_channels, factor_out, height, width, args))
                    n_channels -= factor_out

                layers.append(Squeeze())
                n_channels *= 4
                height //= 2
                width //= 2

        self.layers = torch.nn.ModuleList(layers)
        self.z_size = (n_channels, height, width)

    def forward(self, z, ldj, pys=(), ys=(), reverse=False):
        """Apply the layers forward, or their inverses in reverse order.

        Args:
            z: input tensor.
            ldj: accumulator threaded through every layer call.
            pys, ys: in forward mode, each SplitPrior's ``py`` and ``y``
                outputs are appended to these tuples. In reverse mode ``ys``
                supplies the ``y`` for each SplitPrior, consumed from the end
                (the last SplitPrior first); once ``ys`` is exhausted the
                remaining SplitPriors get ``y=None``.
            reverse: if True, walk the layers backwards using
                ``layer.inverse`` for SplitPriors and ``reverse=True`` for
                all other layers.

        Returns:
            ``(z, ldj, pys, ys)``. In reverse mode ``pys`` is returned
            unchanged and ``ys`` holds whatever entries were not consumed.
        """
        if not reverse:
            for layer in self.layers:
                if isinstance(layer, SplitPrior):
                    py, y, z, ldj = layer(z, ldj)
                    pys += (py,)
                    ys += (y,)
                else:
                    z, ldj = layer(z, ldj)
        else:
            for layer in reversed(self.layers):
                if isinstance(layer, SplitPrior):
                    if len(ys) > 0:
                        z, ldj = layer.inverse(z, ldj, y=ys[-1])
                        # Pop the entry just consumed.
                        ys = ys[:-1]
                    else:
                        z, ldj = layer.inverse(z, ldj, y=None)
                else:
                    z, ldj = layer(z, ldj, reverse=True)

        return z, ldj, pys, ys

    def decode(self, z, ldj, state, decode_fn):
        """Walk the layers in reverse, delegating SplitPriors to ``layer.decode``.

        Each SplitPrior is called as ``layer.decode(z, ldj, state, decode_fn)``
        and returns an updated ``(z, ldj, state)``. Every other layer is run
        with ``reverse=True``.

        Returns:
            ``(z, ldj)``.

        NOTE(review): the ``state`` produced by the last SplitPrior is not
        returned to the caller — confirm callers do not need it.
        """
        for layer in reversed(self.layers):
            if isinstance(layer, SplitPrior):
                z, ldj, state = layer.decode(z, ldj, state, decode_fn)
            else:
                z, ldj = layer(z, ldj, reverse=True)

        return z, ldj
|