This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../integer_discrete_flows/models/generative_flows.py
2025-11-07 12:54:36 +01:00

175 lines
4.7 KiB
Python

"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import numpy as np
from models.utils import Base
from .priors import SplitPrior
from .coupling import Coupling
# Module-level flag; nothing in the visible part of this file reads it.
# NOTE(review): presumably consulted by external test code -- confirm.
UNIT_TESTING = False
def space_to_depth(x):
    """Fold every 2x2 spatial patch of ``x`` into the channel axis.

    ``x`` is indexed as (batch, channels, height, width).  The result has
    four times as many channels and half the height and width.  Height and
    width must be even, otherwise the underlying ``view`` call fails.
    """
    batch, channels, height, width = (x.size(dim) for dim in range(4))
    half_h, half_w = height // 2, width // 2
    # Expose the position inside each 2x2 patch as two axes of length 2.
    patches = x.view(batch, channels, half_h, 2, half_w, 2)
    # Move both in-patch axes so they sit directly after the channel axis.
    patches = patches.permute((0, 1, 3, 5, 2, 4)).contiguous()
    # Merge the two in-patch axes into the channel axis.
    return patches.view(batch, channels * 4, half_h, half_w)
def depth_to_space(x):
    """Unfold each group of four channels of ``x`` into 2x2 spatial patches.

    ``x`` is indexed as (batch, channels, height, width).  The result has a
    quarter of the channels and twice the height and width.  The channel
    count must be divisible by four, otherwise the underlying ``view`` call
    fails.
    """
    batch, channels, height, width = (x.size(dim) for dim in range(4))
    groups = channels // 4
    # Split the channel axis into (group, row offset, column offset).
    blocks = x.view(batch, groups, 2, 2, height, width)
    # Place each offset axis right after the spatial axis it expands.
    blocks = blocks.permute((0, 1, 4, 2, 5, 3)).contiguous()
    # Merge the offsets into the height and width axes.
    return blocks.view(batch, groups, height * 2, width * 2)
def int_shape(x):
    """Return the dimensions of ``x`` as a list of plain Python ints."""
    return [int(dim) for dim in x.size()]
class Flatten(Base):
    """Layer that collapses every axis after the first into a single axis."""

    def forward(self, x):
        """Return ``x`` viewed with shape ``(x.size(0), -1)``."""
        leading = x.size(0)
        flat = x.view(leading, -1)
        return flat
class Reshape(Base):
    """Layer that views its input with a fixed shape after the first axis."""

    def __init__(self, shape):
        """Remember ``shape``, the target dimensions following axis 0."""
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), *self.shape)``."""
        target = (x.size(0), *self.shape)
        return x.view(target)
class Reverse(Base):
    """Layer that reverses the order of axis 1 of a 4-D tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, z, reverse=False):
        """Return ``z`` with the entries along axis 1 in reversed order.

        ``reverse`` is accepted but never read: flipping an axis twice
        restores the original order, so both directions run the same code.
        Only the tensor is returned.
        """
        last = z.size(1) - 1
        # Index sequence last, last - 1, ..., 0 along axis 1.
        descending = torch.arange(last, -1, -1).long()
        return z[:, descending, :, :]
class Permute(Base):
    """Fixed random permutation of axis 1 of a 4-D tensor.

    The permutation is drawn once, at construction time, from NumPy's global
    random state.  Both index tensors are stored as plain attributes rather
    than as parameters.
    NOTE(review): if ``Base`` is a ``torch.nn.Module``, plain tensor
    attributes are not part of ``state_dict()``, so reproducing a trained
    model presumably relies on re-seeding NumPy identically -- confirm.
    """

    def __init__(self, n_channels):
        """Draw a random ordering of ``n_channels`` indices and its inverse."""
        super().__init__()

        # Use an explicit 64-bit dtype: the 'int' alias is platform dependent
        # (32-bit on some NumPy/OS combinations), and the resulting tensors
        # are used as index tensors in ``forward``.
        permutation = np.arange(n_channels, dtype=np.int64)
        np.random.shuffle(permutation)

        # Scatter positions so that permutation_inv[permutation[i]] == i.
        permutation_inv = np.zeros(n_channels, dtype=np.int64)
        permutation_inv[permutation] = np.arange(n_channels, dtype=np.int64)

        self.permutation = torch.from_numpy(permutation)
        self.permutation_inv = torch.from_numpy(permutation_inv)

    def forward(self, z, ldj, reverse=False):
        """Reorder axis 1 of ``z`` and pass ``ldj`` through unchanged.

        Args:
            z: tensor indexed with four axes; axis 1 is reordered.
            ldj: value returned as-is (a permutation does not modify it).
            reverse: when truthy, apply the inverse ordering instead.

        Returns:
            Tuple ``(z, ldj)``.
        """
        index = self.permutation_inv if reverse else self.permutation
        return z[:, index, :, :], ldj

    def InversePermute(self):
        """Return a new ``Permute`` whose forward pass undoes this one."""
        # The constructor draws a throwaway permutation (advancing NumPy's
        # global random state) before both index tensors are replaced below.
        # Kept as-is so the sequence of random draws stays unchanged.
        inv_permute = Permute(len(self.permutation))
        inv_permute.permutation = self.permutation_inv
        inv_permute.permutation_inv = self.permutation
        return inv_permute
class Squeeze(Base):
    """Layer that trades spatial resolution for channels, and back."""

    def __init__(self):
        super().__init__()

    def forward(self, z, ldj, reverse=False):
        """Apply ``space_to_depth`` (or ``depth_to_space`` when reversing).

        Args:
            z: tensor handed to the selected rearrangement function.
            ldj: value returned unchanged alongside the result.
            reverse: when truthy, use ``depth_to_space`` instead.

        Returns:
            Tuple ``(z, ldj)``.
        """
        rearrange = depth_to_space if reverse else space_to_depth
        return rearrange(z), ldj
class GenerativeFlow(Base):
    """Multi-level stack of squeeze, permutation, coupling and split layers.

    The stack starts with one ``Squeeze``.  Each level then adds
    ``args.n_flows`` pairs of ``Permute`` + ``Coupling``.  Every level except
    the last is followed by an optional ``SplitPrior`` (skipped when
    ``args.splitprior_type == 'none'``) and another ``Squeeze``.
    """

    def __init__(self, n_channels, height, width, args):
        """Build the layer stack and record the final latent size.

        Args:
            n_channels: channel count of the input before the first squeeze.
            height: input height before the first squeeze.
            width: input width before the first squeeze.
            args: configuration object; ``n_levels``, ``n_flows`` and
                ``splitprior_type`` are read here, and the object is also
                forwarded to ``Coupling`` and ``SplitPrior``.
        """
        super().__init__()

        stack = [Squeeze()]
        n_channels, height, width = n_channels * 4, height // 2, width // 2

        last_level = args.n_levels - 1
        for level in range(args.n_levels):
            for _ in range(args.n_flows):
                stack.append(Permute(n_channels))
                stack.append(Coupling(n_channels, height, width, args))

            if level >= last_level:
                # No split or squeeze after the final level.
                continue

            if args.splitprior_type != 'none':
                # Factor out half of the channels at this level.
                factor_out = n_channels // 2
                stack.append(
                    SplitPrior(n_channels, factor_out, height, width, args))
                n_channels -= factor_out

            stack.append(Squeeze())
            n_channels, height, width = n_channels * 4, height // 2, width // 2

        self.layers = torch.nn.ModuleList(stack)
        # Shape (channels, height, width) of what leaves the last layer.
        self.z_size = (n_channels, height, width)

    def forward(self, z, ldj, pys=(), ys=(), reverse=False):
        """Run the stack front-to-back, or back-to-front when ``reverse``.

        Forward direction: each ``SplitPrior`` returns ``(py, y, z, ldj)``;
        ``py`` and ``y`` are appended to ``pys`` and ``ys``.

        Reverse direction: each ``SplitPrior`` is inverted with the last
        entry of ``ys`` (which is then dropped), or with ``y=None`` once
        ``ys`` is empty.  Other layers are called with ``reverse=True``.

        Returns:
            Tuple ``(z, ldj, pys, ys)``.
        """
        if reverse:
            for layer in reversed(list(self.layers)):
                if not isinstance(layer, SplitPrior):
                    z, ldj = layer(z, ldj, reverse=True)
                    continue

                has_stored = len(ys) > 0
                z, ldj = layer.inverse(
                    z, ldj, y=ys[-1] if has_stored else None)
                if has_stored:
                    # Drop the entry that was just consumed.
                    ys = ys[:-1]
        else:
            for layer in self.layers:
                if isinstance(layer, SplitPrior):
                    py, y, z, ldj = layer(z, ldj)
                    pys += (py,)
                    ys += (y,)
                else:
                    z, ldj = layer(z, ldj)

        return z, ldj, pys, ys

    def decode(self, z, ldj, state, decode_fn):
        """Walk the stack back-to-front, decoding at each ``SplitPrior``.

        ``SplitPrior`` layers are asked to ``decode`` with the running
        ``state`` and ``decode_fn``; every other layer is called with
        ``reverse=True``.  Only ``(z, ldj)`` is returned; the updated
        ``state`` is not.
        """
        for layer in reversed(list(self.layers)):
            if isinstance(layer, SplitPrior):
                z, ldj, state = layer.decode(z, ldj, state, decode_fn)
                continue
            z, ldj = layer(z, ldj, reverse=True)

        return z, ldj