feat: initial for IDF
commit ef4684ef39
27 changed files with 2830 additions and 0 deletions
19 integer_discrete_flows/LICENSE (Normal file)
@@ -0,0 +1,19 @@
Copyright (c) 2019 Emiel Hoogeboom, Jorn Peters, Rianne van den Berg, Max Welling

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
29 integer_discrete_flows/README.md (Normal file)
@@ -0,0 +1,29 @@
# Integer Discrete Flows and Lossless Compression

This repository contains the code for the experiments presented in [1].

## Usage

### CIFAR10 setup:
```
python main_experiment.py --n_flows 8 --n_levels 3 --n_channels 512 --coupling_type 'densenet' --densenet_depth 12 --n_mixtures 5 --splitprior_type 'densenet'
```

### ImageNet32 setup:
```
python main_experiment.py --evaluate_interval_epochs 5 --n_flows 8 --n_levels 3 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet32' --epochs 100 --lr_decay 0.99
```

### ImageNet64 setup:
```
python main_experiment.py --evaluate_interval_epochs 1 --n_flows 8 --n_levels 4 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet64' --epochs 20 --lr_decay 0.99 --batch_size 64
```

# Acknowledgements
The Robert Bosch GmbH is acknowledged for financial support.

# References
[1] Hoogeboom, Emiel, Jorn W. T. Peters, Rianne van den Berg, and Max Welling. "Integer Discrete Flows and Lossless Compression." Conference on Neural Information Processing Systems (2019).
0 integer_discrete_flows/coding/__init__.py (Normal file)
132 integer_discrete_flows/coding/coder.py (Normal file)
@@ -0,0 +1,132 @@
import numpy as np
from . import rans
from utils.distributions import discretized_logistic_cdf, \
    mixture_discretized_logistic_cdf
import torch

precision = 24
n_bins = 4096


def cdf_fn(z, pz, variable_type, distribution_type, inverse_bin_width):
    if variable_type == 'discrete':
        if distribution_type == 'logistic':
            if len(pz) == 2:
                return discretized_logistic_cdf(
                    z, *pz, inverse_bin_width=inverse_bin_width)
            elif len(pz) == 3:
                return mixture_discretized_logistic_cdf(
                    z, *pz, inverse_bin_width=inverse_bin_width)
        elif distribution_type == 'normal':
            pass

    elif variable_type == 'continuous':
        if distribution_type == 'logistic':
            pass
        elif distribution_type == 'normal':
            pass
        elif distribution_type == 'steplogistic':
            pass

    raise ValueError


def CDF_fn(pz, bin_width, variable_type, distribution_type):
    mean = pz[0] if len(pz) == 2 else pz[0][..., (pz[0].size(-1) - 1) // 2]
    MEAN = torch.round(mean / bin_width).long()

    bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
    bin_locations = bin_locations.float() * bin_width
    bin_locations = bin_locations.to(device=pz[0].device)

    pz = [param[:, :, :, :, None] for param in pz]
    cdf = cdf_fn(
        bin_locations - bin_width,
        pz,
        variable_type,
        distribution_type,
        1. / bin_width).cpu().numpy()

    # Compute CDFs, reweigh to give all bins at least
    # 1 / (2^precision) probability.
    # CDF is equal to floor[cdf * (2^precision - n_bins)] + range(n_bins)
    CDFs = (cdf * ((1 << precision) - n_bins)).astype('int') \
        + np.arange(n_bins)

    return CDFs, MEAN


def encode_sample(
        z, pz, variable_type, distribution_type, bin_width=1./256, state=None):
    if state is None:
        state = rans.x_init
    else:
        state = rans.unflatten(state)

    CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)

    # z is transformed to Z to match the indices for the CDFs array
    Z = torch.round(z / bin_width).long() + n_bins // 2 - MEAN
    Z = Z.cpu().numpy()

    if not ((np.sum(Z < 0) == 0 and np.sum(Z >= n_bins - 1) == 0)):
        print('Z out of allowed range of values, canceling compression')
        return None

    Z, CDFs = Z.reshape(-1), CDFs.reshape(-1, n_bins).copy()
    for symbol, cdf in zip(Z[::-1], CDFs[::-1]):
        statfun = statfun_encode(cdf)
        state = rans.append_symbol(statfun, precision)(state, symbol)

    state = rans.flatten(state)

    return state


def decode_sample(
        state, pz, variable_type, distribution_type, bin_width=1./256):
    state = rans.unflatten(state)

    device = pz[0].device
    size = pz[0].size()[0:4]

    CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)

    CDFs = CDFs.reshape(-1, n_bins)
    result = np.zeros(len(CDFs), dtype=int)
    for i, cdf in enumerate(CDFs):
        statfun = statfun_decode(cdf)
        state, symbol = rans.pop_symbol(statfun, precision)(state)
        result[i] = symbol

    Z_flat = torch.from_numpy(result).to(device)
    Z = Z_flat.view(size) - n_bins // 2 + MEAN

    z = Z.float() * bin_width

    state = rans.flatten(state)

    return state, z


def statfun_encode(CDF):
    def _statfun_encode(symbol):
        return CDF[symbol], CDF[symbol + 1] - CDF[symbol]
    return _statfun_encode


def statfun_decode(CDF):
    def _statfun_decode(cf):
        # Search for s such that CDF[s] <= cf < CDF[s + 1]
        s = np.searchsorted(CDF, cf, side='right')
        s = s - 1
        start, freq = statfun_encode(CDF)(s)
        return s, (start, freq)
    return _statfun_decode


def encode(x, symbol):
    return rans.append_symbol(statfun_encode, precision)(x, symbol)


def decode(x):
    return rans.pop_symbol(statfun_decode, precision)(x)
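Sketch (not part of the commit): why the reweighting in `CDF_fn` is safe. Scaling a monotone CDF by `2^precision - n_bins` and adding `range(n_bins)` makes the quantized CDF strictly increasing, so every bin keeps a frequency of at least 1, i.e. probability at least `1 / 2^precision`, and the rANS coder never meets a zero-frequency symbol. The `cdf` below is a stand-in, not model output.
```
# Illustration only: a uniform stand-in CDF over the bins.
import numpy as np

precision, n_bins = 24, 4096
cdf = np.linspace(0., 1., n_bins)
CDFs = (cdf * ((1 << precision) - n_bins)).astype('int') + np.arange(n_bins)

freqs = np.diff(CDFs)
assert freqs.min() >= 1  # every symbol keeps at least 1 / 2^precision mass
```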
67 integer_discrete_flows/coding/rans.py (Normal file)
@@ -0,0 +1,67 @@
"""
Closely based on https://github.com/rygorous/ryg_rans/blob/master/rans64.h

ROUGH GUIDE:
We use the pythonic names 'append' and 'pop' for encoding and decoding
respectively. The compressed state 'x' is an immutable stack, implemented using
a cons list.

x: the current stack-like state of the encoder/decoder.

precision: the natural numbers are divided into ranges of size 2^precision.

start & freq: start indicates the beginning of the range in [0, 2^precision-1]
that the current symbol is represented by. freq is the length of the range.
freq is chosen such that p(symbol) ~= freq/2^precision.
"""
import numpy as np
from functools import reduce


rans_l = 1 << 31  # the lower bound of the normalisation interval
tail_bits = (1 << 32) - 1

x_init = (rans_l, ())


def append(x, start, freq, precision):
    """Encodes a symbol with range [start, start + freq). All frequencies are
    assumed to sum to "1 << precision", and the resulting bits get written to
    x."""
    if x[0] >= ((rans_l >> precision) << 32) * freq:
        x = (x[0] >> 32, (x[0] & tail_bits, x[1]))
    return ((x[0] // freq) << precision) + (x[0] % freq) + start, x[1]


def pop(x_, precision):
    """Advances in the bit stream by "popping" a single symbol with range start
    "start" and frequency "freq"."""
    cf = x_[0] & ((1 << precision) - 1)
    def pop(start, freq):
        x = freq * (x_[0] >> precision) + cf - start, x_[1]
        return ((x[0] << 32) | x[1][0], x[1][1]) if x[0] < rans_l else x
    return cf, pop


def append_symbol(statfun, precision):
    def append_(x, symbol):
        start, freq = statfun(symbol)
        return append(x, start, freq, precision)
    return append_


def pop_symbol(statfun, precision):
    def pop_(x):
        cf, pop_fun = pop(x, precision)
        symbol, (start, freq) = statfun(cf)
        return pop_fun(start, freq), symbol
    return pop_


def flatten(x):
    """Flatten a rans state x into a 1d numpy array."""
    out, x = [x[0] >> 32, x[0]], x[1]
    while x:
        x_head, x = x
        out.append(x_head)
    return np.asarray(out, dtype=np.uint32)


def unflatten(arr):
    """Unflatten a 1d numpy array into a rans state."""
    return (int(arr[0]) << 32 | int(arr[1]),
            reduce(lambda tl, hd: (int(hd), tl), reversed(arr[2:]), ()))
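Sketch (not part of the commit): a minimal round trip through the primitives above, using an illustrative uniform model over 4 symbols at 16-bit precision. It shows the stack discipline the coder relies on: the last appended symbol pops first.
```
import numpy as np

prec = 16
CDF = np.array([0, 1, 2, 3, 4]) << 14         # uniform model over 4 symbols

def statfun(symbol):                          # symbol -> (start, freq)
    return CDF[symbol], CDF[symbol + 1] - CDF[symbol]

def statfun_dec(cf):                          # cumulative freq -> symbol, (start, freq)
    s = np.searchsorted(CDF, cf, side='right') - 1
    return s, statfun(s)

state = x_init
for sym in [3, 1]:                            # append in reverse of decode order
    state = append_symbol(statfun, prec)(state, sym)

state, s1 = pop_symbol(statfun_dec, prec)(state)
state, s2 = pop_symbol(statfun_dec, prec)(state)
assert (s1, s2) == (1, 3)                     # stack-like: last appended pops first
```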
188 integer_discrete_flows/experiment_coding.py (Normal file)
@@ -0,0 +1,188 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import argparse
import torch
import torch.utils.data
import numpy as np

from utils.load_data import load_dataset


parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'],
                    metavar='DATASET',
                    help='Dataset choice.')

parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')

parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')

parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
                    help='how many batches to wait before logging training status')

parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
                    help='Evaluate per how many epochs')


# optimization settings
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
                    help='number of epochs to train (default: 2000)')
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
                    help='number of early stopping epochs')

parser.add_argument('-bs', '--batch_size', type=int, default=10, metavar='BATCH_SIZE',
                    help='input batch size for training (default: 10)')
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
                    help='learning rate')
parser.add_argument('--warmup', type=int, default=10,
                    help='number of warmup epochs')

parser.add_argument('--data_augmentation_level', type=int, default=2,
                    help='data augmentation level')

parser.add_argument('--no_decode', action='store_true', default=False,
                    help='disables decoding')


args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}


def encode_images(img, model, decode):
    batchsize, img_c, img_h, img_w = img.size()
    c, h, w = model.args.input_size

    assert img_h == img_w and h == w

    if img_h != h:
        assert img_h % h == 0
        steps = img_h // h

        states = [[] for i in range(batchsize)]
        state_sizes = [0 for i in range(batchsize)]
        bpd = [0 for i in range(batchsize)]
        error = 0

        for j in range(steps):
            for i in range(steps):
                r = encode_patches(
                    img[:, :, j*h:(j+1)*h, i*w:(i+1)*w], model, decode)
                for b in range(batchsize):

                    if r[0][b] is None:
                        states[b].append(None)
                    else:
                        states[b].extend(r[0][b])
                    state_sizes[b] += r[1][b]
                    bpd[b] += r[2][b] / steps**2
                error += r[3]
        return states, state_sizes, bpd, error
    else:
        return encode_patches(img, model, decode)


def encode_patches(imgs, model, decode):
    batchsize, img_c, img_h, img_w = imgs.size()
    c, h, w = model.args.input_size
    assert img_h == h and img_w == w

    states = model.encode(imgs)

    bpd = model.forward(imgs)[1].cpu().numpy()

    state_sizes = []
    error = 0

    for b in range(batchsize):
        if states[b] is None:
            # Using escape bit ;)
            state_sizes += [8 * img_c * img_h * img_w + 1]

            # Error remains unchanged.
            print('Escaping, not encoding.')

        else:
            if decode:
                x_recon = model.decode([states[b]])

                error += torch.sum(
                    torch.abs(x_recon.int() - imgs[b].int())).item()

            # Append state plus an escape bit
            state_sizes += [32 * len(states[b]) + 1]

    return states, state_sizes, bpd, error


def run(args, kwargs):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args.snap_dir = snap_dir = \
        'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'

    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    final_model = torch.load(snap_dir + 'a.model')
    if hasattr(final_model, 'module'):
        final_model = final_model.module
    final_model = final_model.cuda()

    sizes = []
    errors = []
    bpds = []

    import time
    start = time.time()

    t = 0
    with torch.no_grad():
        for data, _ in test_loader:
            if args.cuda:
                data = data.cuda()

            state, state_sizes, bpd, error = \
                encode_images(data, final_model, decode=not args.no_decode)

            errors += [error]
            bpds.extend(bpd)
            sizes.extend(state_sizes)

            t += len(data)

            print(
                'Examples: {}/{} bpd compression: {:.3f} error: {},'
                ' analytical bpd {:.3f}'.format(
                    t, len(test_loader.dataset),
                    np.mean(sizes) / np.prod(data.size()[1:]),
                    np.sum(errors),
                    np.mean(bpds)
                ))

    if args.no_decode:
        print('Not testing decoding.')
    else:
        print('Error: {}'.format(np.sum(errors)))

    print('Took {:.3f} seconds / example'.format((time.time() - start) / t))
    print('Final bpd: {:.3f} error: {}'.format(
        np.mean(sizes) / np.prod(data.size()[1:]),
        np.sum(errors)))


if __name__ == "__main__":

    run(args, kwargs)
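Sketch (not part of the commit): the accounting behind the printed "bpd compression" number. Each flattened rANS state word is a uint32, one escape bit marks whether a sample was coded or stored raw at 8 bits per subpixel, and bits are divided by the subpixel count. The state length below is hypothetical.
```
state_words = 1500                 # hypothetical flattened rANS state length
coded_bits = 32 * state_words + 1  # mirrors: state_sizes += [32 * len(states[b]) + 1]
escape_bits = 8 * 3 * 32 * 32 + 1  # raw fallback for one CIFAR10 image
subpixels = 3 * 32 * 32
print('coded: %.3f bpd, escape: %.3f bpd'
      % (coded_bits / subpixels, escape_bits / subpixels))
```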
105 integer_discrete_flows/experiment_progressive_loading.py (Normal file)
@@ -0,0 +1,105 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import argparse
import time
import torch
import torch.utils.data
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

import os
from optimization.training import train, evaluate
from utils.load_data import load_dataset
from utils.plotting import plot_training_curve
import imageio


parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'],
                    metavar='DATASET',
                    help='Dataset choice.')

parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE',
                    help='input batch size for training (default: 256)')

parser.add_argument('--data_augmentation_level', type=int, default=2,
                    help='data augmentation level')

parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')


args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}


def run(args, kwargs):

    args.snap_dir = snap_dir = \
        'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'

    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    final_model = torch.load(snap_dir + 'a.model')
    if hasattr(final_model, 'module'):
        final_model = final_model.module

    from models.backround import SmoothRound
    for module in final_model.modules():
        if isinstance(module, SmoothRound):
            module._round_decay = 1.

    exp_dir = snap_dir + 'partials/'
    os.makedirs(exp_dir, exist_ok=True)

    images = []
    with torch.no_grad():
        for data, _ in test_loader:

            if args.cuda:
                data = data.cuda()

            for i in range(len(data)):
                _, _, _, pz, z, pys, ys, ldj = final_model.forward(data[i:i+1])

                for j in range(len(ys) + 1):
                    x_recon = final_model.inverse(
                        z,
                        ys[len(ys) - j:])

                    images.append(x_recon.float())

                if i == 10:
                    break
            break

    for j in range(len(ys) + 1):

        grid = make_grid(
            torch.stack(images[j::len(ys) + 1], dim=0).squeeze(),
            nrow=11, padding=0,
            normalize=True, range=None,
            scale_each=False, pad_value=0)

        imageio.imwrite(
            exp_dir + 'loaded{j}.png'.format(j=j),
            grid.cpu().numpy().transpose(1, 2, 0))


if __name__ == "__main__":

    run(args, kwargs)
285 integer_discrete_flows/main_experiment.py (Normal file)
@@ -0,0 +1,285 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import argparse
import time
import torch
import torch.utils.data
import torch.optim as optim
import numpy as np
import math
import random

import os

import datetime

from optimization.training import train, evaluate
from utils.load_data import load_dataset

parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

parser.add_argument('-d', '--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'imagenet32', 'imagenet64'],
                    metavar='DATASET',
                    help='Dataset choice.')

parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')

parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')

parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
                    help='how many batches to wait before logging training status')

parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
                    help='Evaluate per how many epochs')

parser.add_argument('-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR',
                    help='output directory for model snapshots etc.')

fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('-te', '--testing', action='store_true', dest='testing',
                help='evaluate on test set after training')
fp.add_argument('-va', '--validation', action='store_false', dest='testing',
                help='only evaluate on validation set')
parser.set_defaults(testing=True)

# optimization settings
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
                    help='number of epochs to train (default: 2000)')
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
                    help='number of early stopping epochs')

parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE',
                    help='input batch size for training (default: 256)')
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
                    help='learning rate')
parser.add_argument('--warmup', type=int, default=10,
                    help='number of warmup epochs')

parser.add_argument('--data_augmentation_level', type=int, default=2,
                    help='data augmentation level')

parser.add_argument('--variable_type', type=str, default='discrete',
                    help='variable type of data distribution: discrete/continuous',
                    choices=['discrete', 'continuous'])
parser.add_argument('--distribution_type', type=str, default='logistic',
                    choices=['logistic', 'normal', 'steplogistic'],
                    help='distribution type: logistic/normal')
parser.add_argument('--n_flows', type=int, default=8,
                    help='number of flows per level')
parser.add_argument('--n_levels', type=int, default=3,
                    help='number of levels')

parser.add_argument('--n_bits', type=int, default=8,
                    help='number of bits of the input data (default: 8)')

# ---------------- SETTINGS CONCERNING NETWORKS -------------
parser.add_argument('--densenet_depth', type=int, default=8,
                    help='Depth of densenets')
parser.add_argument('--n_channels', type=int, default=512,
                    help='number of channels in coupling and splitprior')
# ---------------- ----------------------------- -------------


# ---------------- SETTINGS CONCERNING COUPLING LAYERS -------------
parser.add_argument('--coupling_type', type=str, default='shallow',
                    choices=['shallow', 'resnet', 'densenet'],
                    help='Type of coupling layer')
parser.add_argument('--splitfactor', default=0, type=int,
                    help='Split factor for coupling layers.')

parser.add_argument('--split_quarter', dest='split_quarter', action='store_true',
                    help='Split coupling layer on quarter')
parser.add_argument('--no_split_quarter', dest='split_quarter', action='store_false')
parser.set_defaults(split_quarter=True)
# ---------------- ----------------------------------- -------------


# ---------------- SETTINGS CONCERNING SPLITPRIORS -------------
parser.add_argument('--splitprior_type', type=str, default='shallow',
                    choices=['none', 'shallow', 'resnet', 'densenet'],
                    help='Type of splitprior. Use \'none\' for no splitprior')
# ---------------- ------------------------------- -------------


# ---------------- SETTINGS CONCERNING PRIORS -------------
parser.add_argument('--n_mixtures', type=int, default=1,
                    help='number of mixtures')
# ---------------- ------------------------------- -------------

parser.add_argument('--hard_round', dest='hard_round', action='store_true',
                    help='Rounding of translation in discrete models. Weird '
                         'probabilistic implications, only for experimental phase')
parser.add_argument('--no_hard_round', dest='hard_round', action='store_false')
parser.set_defaults(hard_round=True)

parser.add_argument('--round_approx', type=str, default='smooth',
                    choices=['smooth', 'stochastic'])

parser.add_argument('--lr_decay', default=0.999, type=float,
                    help='Learning rate decay factor per epoch')

parser.add_argument('--temperature', default=1.0, type=float,
                    help='Temperature used for BackRound. It is used in '
                         'the SmoothRound module. '
                         '(default=1.0)')

# gpu/cpu
parser.add_argument('--gpu_num', type=int, default=0, metavar='GPU',
                    help='choose GPU to run on.')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if args.manual_seed is None:
    args.manual_seed = random.randint(1, 100000)
random.seed(args.manual_seed)
torch.manual_seed(args.manual_seed)
np.random.seed(args.manual_seed)


kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}


def run(args, kwargs):

    print('\nMODEL SETTINGS: \n', args, '\n')
    print("Random Seed: ", args.manual_seed)

    if 'imagenet' in args.dataset and args.evaluate_interval_epochs > 5:
        args.evaluate_interval_epochs = 5

    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================
    args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
    args.model_signature = args.model_signature.replace(':', '_')

    snapshots_path = os.path.join(args.out_dir, args.variable_type + '_' + args.distribution_type + args.dataset)
    snap_dir = snapshots_path

    snap_dir += '_' + 'flows_' + str(args.n_flows) + '_levels_' + str(args.n_levels)

    snap_dir = snap_dir + '__' + args.model_signature + '/'

    args.snap_dir = snap_dir

    if not os.path.exists(snap_dir):
        os.makedirs(snap_dir)

    with open(snap_dir + 'log.txt', 'a') as ff:
        print('\nMODEL SETTINGS: \n', args, '\n', file=ff)

    # SAVING
    torch.save(args, snap_dir + '.config')

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # ==================================================================================================================
    # SELECT MODEL
    # ==================================================================================================================
    # flow parameters and architecture choice are passed on to model through args
    print(args.input_size)

    import models.Model as Model

    model = Model.Model(args)
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.set_temperature(args.temperature)
    model.enable_hard_round(args.hard_round)

    model_sample = model

    # ====================================
    # INIT
    # ====================================
    # data dependent initialization on CPU
    for batch_idx, (data, _) in enumerate(train_loader):
        model(data)
        break

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model, dim=0)

    model.to(args.device)

    def lr_lambda(epoch):
        return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)
    optimizer = optim.Adamax(model.parameters(), lr=args.learning_rate, eps=1.e-7)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

    # ==================================================================================================================
    # TRAINING
    # ==================================================================================================================
    train_bpd = []
    val_bpd = []

    # for early stopping
    best_val_bpd = np.inf
    best_train_bpd = np.inf
    epoch = 0

    train_times = []

    model.eval()
    model.train()

    for epoch in range(1, args.epochs + 1):
        t_start = time.time()
        scheduler.step()
        tr_loss, tr_bpd = train(epoch, train_loader, model, optimizer, args)
        train_bpd.append(tr_bpd)
        train_times.append(time.time()-t_start)
        print('One training epoch took %.2f seconds' % (time.time()-t_start))

        if epoch < 25 or epoch % args.evaluate_interval_epochs == 0:
            v_loss, v_bpd = evaluate(
                train_loader, val_loader, model, model_sample, args,
                epoch=epoch, file=snap_dir + 'log.txt')

            val_bpd.append(v_bpd)

            # Model save based on TRAIN performance (is heavily correlated with validation performance.)
            if np.mean(tr_bpd) < best_train_bpd:
                best_train_bpd = np.mean(tr_bpd)
                best_val_bpd = v_bpd
                # model.module only exists when wrapped in DataParallel.
                torch.save(model.module if hasattr(model, 'module') else model,
                           snap_dir + 'a.model')
                torch.save(optimizer, snap_dir + 'a.optimizer')
                print('->model saved<-')

            print('(BEST: train bpd {:.4f}, test bpd {:.4f})\n'.format(
                best_train_bpd, best_val_bpd))

            if math.isnan(v_loss):
                raise ValueError('NaN encountered!')

    train_bpd = np.hstack(train_bpd)
    val_bpd = np.array(val_bpd)

    # training time per epoch
    train_times = np.array(train_times)
    mean_train_time = np.mean(train_times)
    std_train_time = np.std(train_times, ddof=1)
    print('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time))

    # ==================================================================================================================
    # EVALUATION
    # ==================================================================================================================
    final_model = torch.load(snap_dir + 'a.model')
    test_loss, test_bpd = evaluate(
        train_loader, test_loader, final_model, final_model, args,
        epoch=epoch, file=snap_dir + 'test_log.txt')

    print('Test loss / bpd: %.2f / %.2f' % (test_loss, test_bpd))


if __name__ == "__main__":

    run(args, kwargs)
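Sketch (not part of the commit): the learning-rate schedule above combines linear warmup with per-epoch geometric decay. A standalone version of the multiplier that `LambdaLR` applies, using the default flag values:
```
import numpy as np

warmup, lr_decay = 10, 0.99        # defaults from the argparse flags above

def lr_lambda(epoch):
    return min(1., (epoch + 1) / warmup) * np.power(lr_decay, epoch)

for epoch in (0, 4, 9, 50):
    print(epoch, round(lr_lambda(epoch), 4))
# 0 -> 0.1, 4 -> 0.4803, 9 -> 0.9135, 50 -> 0.605
```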
BIN integer_discrete_flows/media/IDF_poster.pdf (Normal file)
Binary file not shown.
BIN integer_discrete_flows/media/IDF_slides.pdf (Normal file)
Binary file not shown.
191 integer_discrete_flows/models/Model.py (Normal file)
@@ -0,0 +1,191 @@
import torch
import models.generative_flows as generative_flows
import numpy as np
from models.utils import Base
from .priors import Prior
from optimization.loss import compute_loss_array
from coding.coder import encode_sample, decode_sample


class Normalize(Base):
    def __init__(self, args):
        super().__init__()
        self.n_bits = args.n_bits
        self.variable_type = args.variable_type
        self.input_size = args.input_size

    def forward(self, x, ldj, reverse=False):
        domain = 2.**self.n_bits

        if self.variable_type == 'discrete':
            # Discrete variables will be measured on intervals sized 1/domain.
            # Hence, there is no need to change the log Jacobian determinant.
            dldj = 0
        elif self.variable_type == 'continuous':
            dldj = -np.log(domain) * np.prod(self.input_size)
        else:
            raise ValueError

        if not reverse:
            x = (x - domain / 2) / domain
            ldj += dldj
        else:
            x = x * domain + domain / 2
            ldj -= dldj

        return x, ldj


class Model(Base):
    """
    The base VAE class containing gated convolutional encoder and decoder
    architecture. Can be used as a base class for VAEs with normalizing flows.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type

        n_channels, height, width = args.input_size

        self.normalize = Normalize(args)

        self.flow = generative_flows.GenerativeFlow(
            n_channels, height, width, args)

        self.n_bits = args.n_bits

        self.z_size = self.flow.z_size

        self.prior = Prior(self.z_size, args)

    def dequantize(self, x):
        if self.training:
            x = x + torch.rand_like(x)
        else:
            # Required for stability.
            alpha = 1e-3
            x = x + alpha + torch.rand_like(x) * (1 - 2 * alpha)

        return x

    def loss(self, pz, z, pys, ys, ldj):
        batchsize = z.size(0)
        loss, bpd, bpd_per_prior = \
            compute_loss_array(pz, z, pys, ys, ldj, self.args)

        for module in self.modules():
            if hasattr(module, 'auxillary_loss'):
                loss += module.auxillary_loss() / batchsize

        return loss, bpd, bpd_per_prior

    def forward(self, x):
        """
        Evaluates the model as a whole, encodes and decodes. Note that the log
        det jacobian is zero for a plain VAE (without flows), and z_0 = z_k.
        """
        assert x.dtype == torch.uint8

        x = x.float()

        ldj = torch.zeros_like(x[:, 0, 0, 0])
        if self.variable_type == 'continuous':
            x = self.dequantize(x)
        elif self.variable_type == 'discrete':
            pass
        else:
            raise ValueError

        x, ldj = self.normalize(x, ldj)

        z, ldj, pys, ys = self.flow(x, ldj, pys=(), ys=())

        pz, z, ldj = self.prior(z, ldj)

        loss, bpd, bpd_per_prior = self.loss(pz, z, pys, ys, ldj)

        return loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj

    def inverse(self, z, ys):
        ldj = torch.zeros_like(z[:, 0, 0, 0])
        x, ldj, pys, py = \
            self.flow(z, ldj, pys=[], ys=ys, reverse=True)

        x, ldj = self.normalize(x, ldj, reverse=True)

        x_uint8 = torch.clamp(x, min=0, max=255).to(
            torch.uint8)

        return x_uint8

    def sample(self, n):
        z_sample = self.prior.sample(n)

        ldj = torch.zeros_like(z_sample[:, 0, 0, 0])
        x_sample, ldj, pys, py = \
            self.flow(z_sample, ldj, pys=[], ys=[], reverse=True)

        x_sample, ldj = self.normalize(x_sample, ldj, reverse=True)

        x_sample_uint8 = torch.clamp(x_sample, min=0, max=255).to(
            torch.uint8)

        return x_sample_uint8

    def encode(self, x):
        batchsize = x.size(0)
        _, _, _, pz, z, pys, ys, _ = self.forward(x)

        pjs = list(pys) + [pz]
        js = list(ys) + [z]

        states = []

        for b in range(batchsize):
            state = None
            for pj, j in zip(pjs, js):
                pj_b = [param[b:b+1] for param in pj]
                j_b = j[b:b+1]

                state = encode_sample(
                    j_b, pj_b, self.variable_type,
                    self.distribution_type, state=state)
                if state is None:
                    break

            states.append(state)

        return states

    def decode(self, states):
        def decode_fn(states, pj):
            states = list(states)
            j = []

            for b in range(len(states)):
                pj_b = [param[b:b+1] for param in pj]

                states[b], j_b = decode_sample(
                    states[b], pj_b, self.variable_type,
                    self.distribution_type)

                j.append(j_b)

            j = torch.cat(j, dim=0)
            return states, j

        states, z = self.prior.decode(states, decode_fn=decode_fn)

        ldj = torch.zeros_like(z[:, 0, 0, 0])

        x, ldj = self.flow.decode(z, ldj, states, decode_fn=decode_fn)

        x, ldj = self.normalize(x, ldj, reverse=True)
        x = x.to(dtype=torch.uint8)

        return x
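Sketch (not part of the commit): a small sanity check for `Normalize`, assuming `models.utils.Base` behaves like `torch.nn.Module`; the `args` namespace here is a hypothetical stand-in carrying just the three fields the module reads.
```
import types
import torch

args = types.SimpleNamespace(n_bits=8, variable_type='discrete',
                             input_size=(3, 32, 32))
norm = Normalize(args)

x = torch.randint(0, 256, (1, 3, 32, 32)).float()
ldj = torch.zeros(1)
z, ldj = norm(x, ldj)                     # (x - 128) / 256, roughly [-0.5, 0.5)
x_back, ldj = norm(z, ldj, reverse=True)
assert torch.equal(x_back, x)             # exact: only power-of-two scaling
```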
0 integer_discrete_flows/models/__init__.py (Normal file)
151 integer_discrete_flows/models/backround.py (Normal file)
@@ -0,0 +1,151 @@
import torch
import torch.nn.functional as F
import numpy as np

from models.utils import Base


class RoundStraightThrough(torch.autograd.Function):

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input):
        rounded = torch.round(input, out=None)
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input


_round_straightthrough = RoundStraightThrough().apply


def _stacked_sigmoid(x, temperature, n_approx=3):

    x_ = x - 0.5
    rounded = torch.round(x_)
    x_remainder = x_ - rounded

    size = x_.size()
    x_remainder = x_remainder.view(size + (1,))

    translation = torch.arange(n_approx) - n_approx // 2
    translation = translation.to(device=x.device, dtype=x.dtype)
    translation = translation.view([1] * len(size) + [len(translation)])
    out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)

    return out + rounded - (n_approx // 2)


class SmoothRound(Base):
    def __init__(self):
        self._temperature = None
        self._n_approx = None
        super().__init__()
        self.hard_round = None

    @property
    def temperature(self):
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        self._temperature = value

        if self._temperature <= 0.05:
            self._n_approx = 1
        elif 0.05 < self._temperature < 0.13:
            self._n_approx = 3
        else:
            self._n_approx = 5

    def forward(self, x):
        assert self._temperature is not None
        assert self._n_approx is not None
        assert self.hard_round is not None

        if self.temperature <= 0.25:
            h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx)
        else:
            h = x

        if self.hard_round:
            h = _round_straightthrough(h)

        return h


class StochasticRound(Base):
    def __init__(self):
        super().__init__()
        self.hard_round = None

    def forward(self, x):
        u = torch.rand_like(x)

        h = x + u - 0.5

        if self.hard_round:
            h = _round_straightthrough(h)

        return h


class BackRound(Base):

    def __init__(self, args, inverse_bin_width):
        """
        BackRound is an approximation to Round that allows for backpropagation.

        Approximate the round function using a sum of translated sigmoids.
        The temperature determines how well the round function is approximated,
        i.e., a lower temperature corresponds to a better approximation, at
        the cost of more vanishing gradients.

        BackRound supports the following settings:
        * By setting hard to True and temperature > 0.25, BackRound
          reduces to a round function with a straight-through gradient
          estimator.
        * When using 0 < temperature <= 0.25 and hard = True, the
          output in the forward pass is equivalent to a round function, but the
          gradient is approximated by the gradient of a sum of sigmoids.
        * When using hard = False, the output is not constrained to integers.
        * When temperature > 0.25 and hard = False, BackRound reduces to
          the identity function.

        Arguments
        ---------
        temperature: float
            Temperature used for the stacked sigmoid approximation. If
            temperature is greater than 0.25, the approximation reduces to
            the identity function.
        hard: bool
            If hard is True, a (hard) round is applied before returning. The
            gradient for this is approximated using the straight-through
            estimator.
        """
        super().__init__()
        self.inverse_bin_width = inverse_bin_width
        self.round_approx = args.round_approx

        if args.round_approx == 'smooth':
            self.round = SmoothRound()
        elif args.round_approx == 'stochastic':
            self.round = StochasticRound()
        else:
            raise ValueError

    def forward(self, x):
        if self.round_approx == 'smooth' or self.round_approx == 'stochastic':
            h = x * self.inverse_bin_width

            h = self.round(h)

            return h / self.inverse_bin_width

        else:
            raise ValueError
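Sketch (not part of the commit): the straight-through behaviour BackRound builds on. The forward pass is an exact round, while the backward pass copies gradients through unchanged.
```
import torch

x = torch.linspace(-2., 2., 9, requires_grad=True)
y = _round_straightthrough(x)
assert torch.equal(y, torch.round(x))           # forward: exact rounding

y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))  # backward: identity gradient
```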
142 integer_discrete_flows/models/coupling.py (Normal file)
@@ -0,0 +1,142 @@
"""
Collection of flow strategies
"""

from __future__ import print_function

import torch
import numpy as np
from models.utils import Base
from .backround import BackRound
from .networks import NN


UNIT_TESTING = False


class SplitFactorCoupling(Base):
    def __init__(self, c_in, factor, height, width, args):
        super().__init__()
        self.n_channels = args.n_channels
        self.kernel = 3
        self.input_channel = c_in
        self.round_approx = args.round_approx

        if args.variable_type == 'discrete':
            self.round = BackRound(
                args, inverse_bin_width=2**args.n_bits)
        else:
            self.round = None

        self.split_idx = c_in - (c_in // factor)

        self.nn = NN(
            args=args,
            c_in=self.split_idx,
            c_out=c_in - self.split_idx,
            height=height,
            width=width,
            kernel=self.kernel,
            nn_type=args.coupling_type)

    def forward(self, z, ldj, reverse=False):
        z1 = z[:, :self.split_idx, :, :]
        z2 = z[:, self.split_idx:, :, :]

        t = self.nn(z1)

        if self.round is not None:
            t = self.round(t)

        if not reverse:
            z2 = z2 + t
        else:
            z2 = z2 - t

        z = torch.cat([z1, z2], dim=1)

        return z, ldj


class Coupling(Base):
    def __init__(self, c_in, height, width, args):
        super().__init__()

        if args.split_quarter:
            factor = 4
        elif args.splitfactor > 1:
            factor = args.splitfactor
        else:
            factor = 2

        self.coupling = SplitFactorCoupling(
            c_in, factor, height, width, args=args)

    def forward(self, z, ldj, reverse=False):
        return self.coupling(z, ldj, reverse)


def test_generative_flow():
    import models.networks as networks
    global UNIT_TESTING

    networks.UNIT_TESTING = True
    UNIT_TESTING = True

    batch_size = 17

    input_size = [12, 16, 16]

    class Args():
        def __init__(self):
            self.input_size = input_size
            self.learn_split = False
            self.variable_type = 'continuous'
            self.distribution_type = 'logistic'
            self.round_approx = 'smooth'
            self.coupling_type = 'shallow'
            self.conv_type = 'standard'
            self.densenet_depth = 8
            self.bottleneck = False
            self.n_channels = 512
            self.network1x1 = 'standard'
            self.auxilary_freq = -1
            self.actnorm = False
            self.LU = False
            self.coupling_lifting_L = True
            self.splitprior = True
            self.split_quarter = True
            self.n_levels = 2
            self.n_flows = 2
            self.cond_L = True
            self.n_bits = True

    args = Args()

    x = (torch.randint(256, size=[batch_size] + input_size).float() - 128.) / 256.
    ldj = torch.zeros_like(x[:, 0, 0, 0])

    model = Coupling(c_in=12, height=16, width=16, args=args)

    print(model)

    model.set_temperature(1.)
    model.enable_hard_round()

    model.eval()

    z, ldj = model(x, ldj, reverse=False)

    # Check if gradient computation works
    loss = torch.sum(z**2)
    loss.backward()

    recon, ldj = model(z, ldj, reverse=True)

    sse = torch.sum(torch.pow(x - recon, 2)).item()
    ae = torch.abs(x - recon).sum()
    print('Error in recon: sse {} ae {}'.format(sse / np.prod(input_size), ae))


if __name__ == '__main__':
    test_generative_flow()
175
integer_discrete_flows/models/generative_flows.py
Normal file
175
integer_discrete_flows/models/generative_flows.py
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
"""
|
||||||
|
Collection of flow strategies
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from models.utils import Base
|
||||||
|
from .priors import SplitPrior
|
||||||
|
from .coupling import Coupling
|
||||||
|
|
||||||
|
|
||||||
|
UNIT_TESTING = False
|
||||||
|
|
||||||
|
|
||||||
|
def space_to_depth(x):
|
||||||
|
xs = x.size()
|
||||||
|
# Pick off every second element
|
||||||
|
x = x.view(xs[0], xs[1], xs[2] // 2, 2, xs[3] // 2, 2)
|
||||||
|
# Transpose picked elements next to channels.
|
||||||
|
x = x.permute((0, 1, 3, 5, 2, 4)).contiguous()
|
||||||
|
# Combine with channels.
|
||||||
|
x = x.view(xs[0], xs[1] * 4, xs[2] // 2, xs[3] // 2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def depth_to_space(x):
|
||||||
|
xs = x.size()
|
||||||
|
# Pick off elements from channels
|
||||||
|
x = x.view(xs[0], xs[1] // 4, 2, 2, xs[2], xs[3])
|
||||||
|
# Transpose picked elements next to HW dimensions.
|
||||||
|
x = x.permute((0, 1, 4, 2, 5, 3)).contiguous()
|
||||||
|
# Combine with HW dimensions.
|
||||||
|
x = x.view(xs[0], xs[1] // 4, xs[2] * 2, xs[3] * 2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def int_shape(x):
|
||||||
|
return list(map(int, x.size()))
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(Base):
|
||||||
|
def forward(self, x):
|
||||||
|
return x.view(x.size(0), -1)
|
||||||
|
|
||||||
|
|
||||||
|
class Reshape(Base):
|
||||||
|
def __init__(self, shape):
|
||||||
|
super().__init__()
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x.view(x.size(0), *self.shape)
|
||||||
|
|
||||||
|
|
||||||
|
class Reverse(Base):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, z, reverse=False):
|
||||||
|
flip_idx = torch.arange(z.size(1) - 1, -1, -1).long()
|
||||||
|
z = z[:, flip_idx, :, :]
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class Permute(Base):
|
||||||
|
def __init__(self, n_channels):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
permutation = np.arange(n_channels, dtype='int')
|
||||||
|
np.random.shuffle(permutation)
|
||||||
|
|
||||||
|
permutation_inv = np.zeros(n_channels, dtype='int')
|
||||||
|
permutation_inv[permutation] = np.arange(n_channels, dtype='int')
|
||||||
|
|
||||||
|
self.permutation = torch.from_numpy(permutation)
|
||||||
|
self.permutation_inv = torch.from_numpy(permutation_inv)
|
||||||
|
|
||||||
|
def forward(self, z, ldj, reverse=False):
|
||||||
|
if not reverse:
|
||||||
|
z = z[:, self.permutation, :, :]
|
||||||
|
else:
|
||||||
|
z = z[:, self.permutation_inv, :, :]
|
||||||
|
|
||||||
|
return z, ldj
|
||||||
|
|
||||||
|
def InversePermute(self):
|
||||||
|
inv_permute = Permute(len(self.permutation))
|
||||||
|
inv_permute.permutation = self.permutation_inv
|
||||||
|
inv_permute.permutation_inv = self.permutation
|
||||||
|
return inv_permute


class Squeeze(Base):
    def __init__(self):
        super().__init__()

    def forward(self, z, ldj, reverse=False):
        if not reverse:
            z = space_to_depth(z)
        else:
            z = depth_to_space(z)
        return z, ldj


class GenerativeFlow(Base):
    def __init__(self, n_channels, height, width, args):
        super().__init__()
        layers = []

        layers.append(Squeeze())
        n_channels *= 4
        height //= 2
        width //= 2

        for level in range(args.n_levels):

            for i in range(args.n_flows):
                perm_layer = Permute(n_channels)
                layers.append(perm_layer)

                layers.append(
                    Coupling(n_channels, height, width, args))

            if level < args.n_levels - 1:
                if args.splitprior_type != 'none':
                    # Standard splitprior: factor out half of the channels.
                    factor_out = n_channels // 2
                    layers.append(
                        SplitPrior(n_channels, factor_out, height, width, args))
                    n_channels = n_channels - factor_out

                layers.append(Squeeze())
                n_channels *= 4
                height //= 2
                width //= 2

        self.layers = torch.nn.ModuleList(layers)
        self.z_size = (n_channels, height, width)

    def forward(self, z, ldj, pys=(), ys=(), reverse=False):
        if not reverse:
            for l, layer in enumerate(self.layers):
                if isinstance(layer, SplitPrior):
                    py, y, z, ldj = layer(z, ldj)
                    pys += (py,)
                    ys += (y,)
                else:
                    z, ldj = layer(z, ldj)

        else:
            # Invert the flow by running the layers in reverse order.
            for l, layer in reversed(list(enumerate(self.layers))):
                if isinstance(layer, SplitPrior):
                    if len(ys) > 0:
                        z, ldj = layer.inverse(z, ldj, y=ys[-1])
                        # Pop the last element.
                        ys = ys[:-1]
                    else:
                        # No y given: the splitprior samples it instead.
                        z, ldj = layer.inverse(z, ldj, y=None)
                else:
                    z, ldj = layer(z, ldj, reverse=True)

        return z, ldj, pys, ys

    def decode(self, z, ldj, state, decode_fn):
        # Run the flow in reverse; factored-out variables are recovered from
        # the entropy-coder state via decode_fn.
        for l, layer in reversed(list(enumerate(self.layers))):
            if isinstance(layer, SplitPrior):
                z, ldj, state = layer.decode(z, ldj, state, decode_fn)
            else:
                z, ldj = layer(z, ldj, reverse=True)

        return z, ldj
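

def _demo_flow_shapes(n_channels=3, height=32, width=32, n_levels=3):
    # Hedged sketch, not part of the original file: replays the channel/size
    # bookkeeping of GenerativeFlow.__init__ (with the splitprior enabled)
    # without building any layers. For CIFAR-10 with n_levels=3 this yields
    # z_size == (48, 4, 4).
    n_channels, height, width = n_channels * 4, height // 2, width // 2
    for level in range(n_levels):
        if level < n_levels - 1:
            n_channels -= n_channels // 2   # splitprior factors out half
            n_channels, height, width = n_channels * 4, height // 2, width // 2
    return n_channels, height, width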
154
integer_discrete_flows/models/networks.py
Normal file
@@ -0,0 +1,154 @@
"""
|
||||||
|
Collection of flow strategies
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from models.utils import Base
|
||||||
|
|
||||||
|
|
||||||
|
UNIT_TESTING = False
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dReLU(Base):
|
||||||
|
def __init__(
|
||||||
|
self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0,
|
||||||
|
bias=True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.nn = nn.Conv2d(n_inputs, n_outputs, kernel_size, padding=padding)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.nn(x)
|
||||||
|
|
||||||
|
y = F.relu(h)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(Base):
|
||||||
|
def __init__(self, n_channels, kernel, Conv2dAct):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.nn = torch.nn.Sequential(
|
||||||
|
Conv2dAct(n_channels, n_channels, kernel, padding=1),
|
||||||
|
torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.nn(x)
|
||||||
|
h = F.relu(h + x)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class DenseLayer(Base):
|
||||||
|
def __init__(self, args, n_inputs, growth, Conv2dAct):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
conv1x1 = Conv2dAct(
|
||||||
|
n_inputs, n_inputs, kernel_size=1, stride=1,
|
||||||
|
padding=0, bias=True)
|
||||||
|
|
||||||
|
self.nn = torch.nn.Sequential(
|
||||||
|
conv1x1,
|
||||||
|
Conv2dAct(
|
||||||
|
n_inputs, growth, kernel_size=3, stride=1,
|
||||||
|
padding=1, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.nn(x)
|
||||||
|
|
||||||
|
h = torch.cat([x, h], dim=1)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class DenseBlock(Base):
|
||||||
|
def __init__(
|
||||||
|
self, args, n_inputs, n_outputs, kernel, Conv2dAct):
|
||||||
|
super().__init__()
|
||||||
|
depth = args.densenet_depth
|
||||||
|
|
||||||
|
future_growth = n_outputs - n_inputs
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
for d in range(depth):
|
||||||
|
growth = future_growth // (depth - d)
|
||||||
|
|
||||||
|
layers.append(DenseLayer(args, n_inputs, growth, Conv2dAct))
|
||||||
|
n_inputs += growth
|
||||||
|
future_growth -= growth
|
||||||
|
|
||||||
|
self.nn = torch.nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.nn(x)
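

def _demo_denseblock_growth():
    # Hedged sketch, not part of the original file: replays the growth
    # schedule of DenseBlock.__init__ and checks that the concatenated
    # features land exactly on n_outputs, e.g. 24 -> 536 over depth 12.
    n_inputs, n_outputs, depth = 24, 536, 12
    future_growth = n_outputs - n_inputs
    for d in range(depth):
        growth = future_growth // (depth - d)
        n_inputs += growth
        future_growth -= growth
    assert n_inputs == n_outputs and future_growth == 0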


class Identity(Base):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class NN(Base):
    def __init__(
            self, args, c_in, c_out, height, width, nn_type, kernel=3):
        super().__init__()

        Conv2dAct = Conv2dReLU
        n_channels = args.n_channels

        if nn_type == 'shallow':
            # Assumes args.network1x1 == 'standard'; conv1x1 is undefined
            # otherwise.
            if args.network1x1 == 'standard':
                conv1x1 = Conv2dAct(
                    n_channels, n_channels, kernel_size=1, stride=1,
                    padding=0, bias=False)

            layers = [
                Conv2dAct(c_in, n_channels, kernel, padding=1),
                conv1x1]

            layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)]

        elif nn_type == 'resnet':
            layers = [
                Conv2dAct(c_in, n_channels, kernel, padding=1),
                ResidualBlock(n_channels, kernel, Conv2dAct),
                ResidualBlock(n_channels, kernel, Conv2dAct)]

            layers += [
                torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)
            ]

        elif nn_type == 'densenet':
            layers = [
                DenseBlock(
                    args=args,
                    n_inputs=c_in,
                    n_outputs=n_channels + c_in,
                    kernel=kernel,
                    Conv2dAct=Conv2dAct)]

            layers += [
                torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1)
            ]
        else:
            raise ValueError('Unknown nn_type: {}'.format(nn_type))

        self.nn = torch.nn.Sequential(*layers)

        # Set parameters of last conv-layer to zero, so that couplings start
        # out as identity transformations.
        if not UNIT_TESTING:
            self.nn[-1].weight.data.zero_()
            self.nn[-1].bias.data.zero_()

    def forward(self, x):
        return self.nn(x)
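

def _demo_nn_zero_init():
    # Hedged sketch, not part of the original file: because the final conv is
    # zero-initialised, a freshly built NN maps any input to zeros, so the
    # couplings that wrap it start out as identity transformations. The args
    # namespace below is hypothetical.
    import types
    args = types.SimpleNamespace(n_channels=32, densenet_depth=4)
    net = NN(args, c_in=6, c_out=12, height=8, width=8, nn_type='densenet')
    x = torch.randn(1, 6, 8, 8)
    assert torch.equal(net(x), torch.zeros(1, 12, 8, 8))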
164
integer_discrete_flows/models/priors.py
Normal file
@@ -0,0 +1,164 @@
"""
|
||||||
|
Collection of flow strategies
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import Parameter
|
||||||
|
from utils.distributions import sample_discretized_logistic, \
|
||||||
|
sample_mixture_discretized_logistic, sample_normal, sample_logistic, \
|
||||||
|
sample_discretized_normal, sample_mixture_normal
|
||||||
|
from models.utils import Base
|
||||||
|
from .networks import NN
|
||||||
|
|
||||||
|
|
||||||
|
def sample_prior(px, variable_type, distribution_type, inverse_bin_width):
|
||||||
|
if variable_type == 'discrete':
|
||||||
|
if distribution_type == 'logistic':
|
||||||
|
if len(px) == 2:
|
||||||
|
return sample_discretized_logistic(
|
||||||
|
*px, inverse_bin_width=inverse_bin_width)
|
||||||
|
elif len(px) == 3:
|
||||||
|
return sample_mixture_discretized_logistic(
|
||||||
|
*px, inverse_bin_width=inverse_bin_width)
|
||||||
|
|
||||||
|
elif distribution_type == 'normal':
|
||||||
|
return sample_discretized_normal(
|
||||||
|
*px, inverse_bin_width=inverse_bin_width)
|
||||||
|
|
||||||
|
elif variable_type == 'continuous':
|
||||||
|
if distribution_type == 'logistic':
|
||||||
|
return sample_logistic(*px)
|
||||||
|
elif distribution_type == 'normal':
|
||||||
|
if len(px) == 2:
|
||||||
|
return sample_normal(*px)
|
||||||
|
elif len(px) == 3:
|
||||||
|
return sample_mixture_normal(*px)
|
||||||
|
elif distribution_type == 'steplogistic':
|
||||||
|
return sample_logistic(*px)
|
||||||
|
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
|
class Prior(Base):
|
||||||
|
def __init__(self, size, args):
|
||||||
|
super().__init__()
|
||||||
|
c, h, w = size
|
||||||
|
|
||||||
|
self.inverse_bin_width = 2**args.n_bits
|
||||||
|
self.variable_type = args.variable_type
|
||||||
|
self.distribution_type = args.distribution_type
|
||||||
|
self.n_mixtures = args.n_mixtures
|
||||||
|
|
||||||
|
if self.n_mixtures == 1:
|
||||||
|
self.mu = Parameter(torch.Tensor(c, h, w))
|
||||||
|
self.logs = Parameter(torch.Tensor(c, h, w))
|
||||||
|
elif self.n_mixtures > 1:
|
||||||
|
self.mu = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
||||||
|
self.logs = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
||||||
|
self.pi_logit = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
self.mu.data.zero_()
|
||||||
|
|
||||||
|
if self.n_mixtures > 1:
|
||||||
|
self.pi_logit.data.zero_()
|
||||||
|
for i in range(self.n_mixtures):
|
||||||
|
self.mu.data[..., i] += i - (self.n_mixtures - 1) / 2.
|
||||||
|
|
||||||
|
self.logs.data.zero_()
|
||||||
|
|
||||||
|
def get_pz(self, n):
|
||||||
|
if self.n_mixtures == 1:
|
||||||
|
mu = self.mu.repeat(n, 1, 1, 1)
|
||||||
|
logs = self.logs.repeat(n, 1, 1, 1) # scaling scale
|
||||||
|
return mu, logs
|
||||||
|
|
||||||
|
elif self.n_mixtures > 1:
|
||||||
|
pi = F.softmax(self.pi_logit, dim=-1)
|
||||||
|
mu = self.mu.repeat(n, 1, 1, 1, 1)
|
||||||
|
logs = self.logs.repeat(n, 1, 1, 1, 1)
|
||||||
|
pi = pi.repeat(n, 1, 1, 1, 1)
|
||||||
|
return mu, logs, pi
|
||||||
|
|
||||||
|
def forward(self, z, ldj):
|
||||||
|
pz = self.get_pz(z.size(0))
|
||||||
|
|
||||||
|
return pz, z, ldj
|
||||||
|
|
||||||
|
def sample(self, n):
|
||||||
|
pz = self.get_pz(n)
|
||||||
|
|
||||||
|
z_sample = sample_prior(pz, self.variable_type, self.distribution_type, self.inverse_bin_width)
|
||||||
|
|
||||||
|
return z_sample
|
||||||
|
|
||||||
|
def decode(self, states, decode_fn):
|
||||||
|
pz = self.get_pz(n=len(states))
|
||||||
|
|
||||||
|
states, z = decode_fn(states, pz)
|
||||||
|
return states, z
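

def _demo_prior_sample():
    # Hedged sketch, not part of the original file: build a Prior over a
    # 2x4x4 latent with a hypothetical args namespace and draw discrete
    # samples from the mixture-of-logistics prior.
    import types
    args = types.SimpleNamespace(
        n_bits=8, variable_type='discrete',
        distribution_type='logistic', n_mixtures=5)
    prior = Prior(size=(2, 4, 4), args=args)
    z = prior.sample(n=3)
    assert z.size() == (3, 2, 4, 4)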


class SplitPrior(Base):
    def __init__(self, c_in, factor_out, height, width, args):
        super().__init__()

        self.split_idx = c_in - factor_out
        self.inverse_bin_width = 2**args.n_bits
        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type
        self.input_channel = c_in

        self.nn = NN(
            args=args,
            c_in=c_in - factor_out,
            c_out=factor_out * 2,
            height=height,
            width=width,
            nn_type=args.splitprior_type)

    def get_py(self, z):
        h = self.nn(z)
        # Even channels parametrize the means, odd channels the log-scales.
        mu = h[:, ::2, :, :]
        logs = h[:, 1::2, :, :]

        py = [mu, logs]

        return py

    def split(self, z):
        z1 = z[:, :self.split_idx, :, :]
        y = z[:, self.split_idx:, :, :]
        return z1, y

    def combine(self, z, y):
        result = torch.cat([z, y], dim=1)

        return result

    def forward(self, z, ldj):
        z, y = self.split(z)

        py = self.get_py(z)

        return py, y, z, ldj

    def inverse(self, z, ldj, y):
        # Sample y if it is not given.
        if y is None:
            py = self.get_py(z)
            y = sample_prior(
                py, self.variable_type, self.distribution_type,
                self.inverse_bin_width)

        z = self.combine(z, y)

        return z, ldj

    def decode(self, z, ldj, states, decode_fn):
        py = self.get_py(z)
        states, y = decode_fn(states, py)
        return self.combine(z, y), ldj, states
36
integer_discrete_flows/models/utils.py
Normal file
@@ -0,0 +1,36 @@
import torch


class Base(torch.nn.Module):
    """
    The base class for modules; it provides helpers to toggle rounding-related
    attributes on a module and all of its children.
    """

    def __init__(self):
        super().__init__()

    def _set_child_attribute(self, attr, value):
        r"""Sets the module in rounding mode.

        This has an effect only on certain modules, and only if the variable
        type is discrete.

        Returns:
            Module: self
        """
        if hasattr(self, attr):
            setattr(self, attr, value)

        for module in self.modules():
            if hasattr(module, attr):
                setattr(module, attr, value)
        return self

    def set_temperature(self, value):
        self._set_child_attribute("temperature", value)

    def enable_hard_round(self, mode=True):
        self._set_child_attribute("hard_round", mode)

    def disable_hard_round(self, mode=True):
        self.enable_hard_round(not mode)
0
integer_discrete_flows/optimization/__init__.py
Normal file
148
integer_discrete_flows/optimization/loss.py
Normal file
@@ -0,0 +1,148 @@
from __future__ import print_function

import numpy as np
import torch
from utils.distributions import log_discretized_logistic, \
    log_mixture_discretized_logistic, log_normal, log_discretized_normal, \
    log_logistic, log_mixture_normal
from models.backround import _round_straightthrough


def compute_log_ps(pxs, xs, args):
    # Add likelihoods of intermediate representations.
    inverse_bin_width = 2.**args.n_bits

    log_pxs = []
    for px, x in zip(pxs, xs):

        if args.variable_type == 'discrete':
            if args.distribution_type == 'logistic':
                log_px = log_discretized_logistic(
                    x, *px, inverse_bin_width=inverse_bin_width)
            elif args.distribution_type == 'normal':
                log_px = log_discretized_normal(
                    x, *px, inverse_bin_width=inverse_bin_width)
        elif args.variable_type == 'continuous':
            if args.distribution_type == 'logistic':
                log_px = log_logistic(x, *px)
            elif args.distribution_type == 'normal':
                log_px = log_normal(x, *px)
            elif args.distribution_type == 'steplogistic':
                x = _round_straightthrough(x * inverse_bin_width) / inverse_bin_width
                log_px = log_discretized_logistic(
                    x, *px, inverse_bin_width=inverse_bin_width)

        log_pxs.append(
            torch.sum(log_px, dim=[1, 2, 3]))

    return log_pxs


def compute_log_pz(pz, z, args):
    inverse_bin_width = 2.**args.n_bits

    if args.variable_type == 'discrete':
        if args.distribution_type == 'logistic':
            if args.n_mixtures == 1:
                log_pz = log_discretized_logistic(
                    z, pz[0], pz[1], inverse_bin_width=inverse_bin_width)
            else:
                log_pz = log_mixture_discretized_logistic(
                    z, pz[0], pz[1], pz[2],
                    inverse_bin_width=inverse_bin_width)
        elif args.distribution_type == 'normal':
            log_pz = log_discretized_normal(
                z, *pz, inverse_bin_width=inverse_bin_width)

    elif args.variable_type == 'continuous':
        if args.distribution_type == 'logistic':
            log_pz = log_logistic(z, *pz)
        elif args.distribution_type == 'normal':
            if args.n_mixtures == 1:
                log_pz = log_normal(z, *pz)
            else:
                log_pz = log_mixture_normal(z, *pz)
        elif args.distribution_type == 'steplogistic':
            # Round to the bin grid and score under the discretized logistic.
            z = _round_straightthrough(z * inverse_bin_width) / inverse_bin_width
            log_pz = log_discretized_logistic(
                z, *pz, inverse_bin_width=inverse_bin_width)

    log_pz = torch.sum(
        log_pz,
        dim=[1, 2, 3])

    return log_pz


def compute_loss_function(pz, z, pys, ys, ldj, args):
    """
    Computes the negative log-likelihood loss, averaged over the batch.

    :param pz: parameters of the prior over the final representation z
    :param z: final latent representation
    :param pys: parameters of the splitpriors over the factored-out ys
    :param ys: factored-out intermediate representations
    :param ldj: log-determinant of the Jacobian
    :param args: global parameter settings
    :return: loss, bpd, bpd_per_prior
    """
    batch_size = z.size(0)

    # Get array loss, sum over batch
    loss_array, bpd_array, bpd_per_prior_array = \
        compute_loss_array(pz, z, pys, ys, ldj, args)

    loss = torch.mean(loss_array)
    bpd = torch.mean(bpd_array).item()
    bpd_per_prior = [torch.mean(x) for x in bpd_per_prior_array]

    return loss, bpd, bpd_per_prior


def convert_bpd(log_p, input_size):
    return -log_p / (np.prod(input_size) * np.log(2.))
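

def _demo_convert_bpd():
    # Hedged sketch, not part of the original file: for a CIFAR-10 sized
    # input (3 * 32 * 32 = 3072 dimensions), a log-likelihood of -8000 nats
    # corresponds to 8000 / (3072 * ln 2) ~= 3.757 bits per dimension.
    bpd = convert_bpd(np.float64(-8000.), input_size=(3, 32, 32))
    assert abs(bpd - 3.757) < 1e-3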


def compute_loss_array(pz, z, pys, ys, ldj, args):
    """
    Computes the per-example negative log-likelihood, summed over dimensions
    but not averaged over the batch.

    :param pz: parameters of the prior over the final representation z
    :param z: final latent representation
    :param pys: parameters of the splitpriors over the factored-out ys
    :param ys: factored-out intermediate representations
    :param ldj: log-determinant of the Jacobian
    :param args: global parameter settings
    :return: loss, bpd, bpd_per_prior
    """
    bpd_per_prior = []

    # Likelihood of final representation.
    log_pz = compute_log_pz(pz, z, args)

    bpd_per_prior.append(convert_bpd(log_pz.detach(), args.input_size))

    log_p = log_pz

    # Add likelihoods of intermediate representations.
    if ys:
        log_pys = compute_log_ps(pys, ys, args)

        for log_py in log_pys:
            log_p += log_py

            bpd_per_prior.append(convert_bpd(log_py.detach(), args.input_size))

    log_p += ldj

    loss = -log_p
    bpd = convert_bpd(log_p.detach(), args.input_size)

    return loss, bpd, bpd_per_prior


def calculate_loss(pz, z, pys, ys, ldj, loss_aux, args):
    # Note: loss_aux is unused; compute_loss_function takes no such argument.
    return compute_loss_function(pz, z, pys, ys, ldj, args)
174
integer_discrete_flows/optimization/training.py
Normal file
@@ -0,0 +1,174 @@
from __future__ import print_function
import torch

from optimization.loss import calculate_loss
from utils.visual_evaluation import plot_reconstructions

import numpy as np


def train(epoch, train_loader, model, opt, args):
    model.train()
    train_loss = np.zeros(len(train_loader))
    train_bpd = np.zeros(len(train_loader))

    num_data = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, *args.input_size)

        data = data.to(args.device)

        opt.zero_grad()
        loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data)

        loss = torch.mean(loss)
        bpd = torch.mean(bpd)
        bpd_per_prior = [torch.mean(i) for i in bpd_per_prior]

        loss.backward()
        loss = loss.item()
        train_loss[batch_idx] = loss
        train_bpd[batch_idx] = bpd

        ldj = torch.mean(ldj).item() / np.prod(args.input_size) / np.log(2)

        opt.step()

        num_data += len(data)

        if batch_idx % args.log_interval == 0:
            perc = 100. * batch_idx / len(train_loader)

            tmp = 'Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] \tLoss: {:11.6f}\tbpd: {:8.6f}\tbits ldj: {:8.6f}'
            print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, loss, bpd, ldj))

            print('z min: {:8.3f}, max: {:8.3f}'.format(torch.min(z).item() * 256, torch.max(z).item() * 256))

            print('z bpd: {:.3f}'.format(bpd_per_prior[0]))
            for i in range(1, len(bpd_per_prior)):
                print('y{} bpd: {:.3f}'.format(i - 1, bpd_per_prior[i]))

            print('pz mu', np.mean(pz[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
            print('pz logs ', np.mean(pz[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
            if len(pz) == 3:
                print('pz pi ', np.mean(pz[2].data.cpu().numpy(), axis=(0, 1, 2, 3)))

            for i, py in enumerate(pys):
                print('py{} mu '.format(i), np.mean(py[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
                print('py{} logs '.format(i), np.mean(py[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))

            # Make sure the directory for training plots exists.
            import os
            if not os.path.exists(args.snap_dir + 'training/'):
                os.makedirs(args.snap_dir + 'training/')

    print('====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'.format(
        epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader)))

    return train_loss, train_bpd


def evaluate(train_loader, val_loader, model, model_sample, args, testing=False, file=None, epoch=0):
    model.eval()

    loss_type = 'bpd'

    def analyse(data_loader, plot=False):
        bpds = []
        batch_idx = 0
        with torch.no_grad():
            for data, _ in data_loader:
                batch_idx += 1

                if args.cuda:
                    data = data.cuda()

                data = data.view(-1, *args.input_size)

                loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj = \
                    model(data)
                loss = torch.mean(loss).item()
                batch_bpd = torch.mean(batch_bpd).item()

                bpds.append(batch_bpd)

        bpd = np.mean(bpds)

        with torch.no_grad():
            if not testing and plot:
                x_sample = model_sample.sample(n=100)

                try:
                    plot_reconstructions(
                        x_sample, bpd, loss_type, epoch, args)
                except Exception:
                    print('Not plotting')

        return bpd

    bpd_train = analyse(train_loader)
    bpd_val = analyse(val_loader, plot=True)

    with open(file, 'a') as ff:
        msg = 'epoch {}\ttrain bpd {:.3f}\tval bpd {:.3f}\t'.format(
            epoch,
            bpd_train,
            bpd_val)
        print(msg, file=ff)

    loss = bpd_val * np.prod(args.input_size) * np.log(2.)
    bpd = bpd_val

    file = None

    # Compute log-likelihood
    with torch.no_grad():
        if testing:
            # Note: analyse expects a data loader yielding (data, target)
            # pairs; passing a raw tensor here is kept from the original code.
            test_data = val_loader.dataset.data_tensor

            if args.cuda:
                test_data = test_data.cuda()

            print('Computing log-likelihood on test set')

            model.eval()

            log_likelihood = analyse(test_data)
            # Derive the bpd reported below from the test log-likelihood.
            nll_bpd = log_likelihood / (np.prod(args.input_size) * np.log(2.))

        else:
            log_likelihood = None
            nll_bpd = None

    if file is None:
        if testing:
            print('====> Test set loss: {:.4f}'.format(loss))
            print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood))

            print('====> Test set bpd (elbo): {:.4f}'.format(bpd))
            print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood /
                                                                       (np.prod(args.input_size) * np.log(2.))))

        else:
            print('====> Validation set loss: {:.4f}'.format(loss))
            print('====> Validation set bpd: {:.4f}'.format(bpd))
    else:
        with open(file, 'a') as ff:
            if testing:
                print('====> Test set loss: {:.4f}'.format(loss), file=ff)
                print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood), file=ff)

                print('====> Test set bpd: {:.4f}'.format(bpd), file=ff)
                print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood /
                                                                           (np.prod(args.input_size) * np.log(2.))),
                      file=ff)

            else:
                print('====> Validation set loss: {:.4f}'.format(loss), file=ff)
                print('====> Validation set bpd: {:.4f}'.format(loss / (np.prod(args.input_size) * np.log(2.))),
                      file=ff)

    if not testing:
        return loss, bpd
    else:
        return log_likelihood, nll_bpd
0
integer_discrete_flows/utils/__init__.py
Normal file
209
integer_discrete_flows/utils/distributions.py
Normal file
@@ -0,0 +1,209 @@
from __future__ import print_function
import torch
import torch.utils.data
import torch.nn.functional as F

import numpy as np
import math

MIN_EPSILON = 1e-5
MAX_EPSILON = 1. - 1e-5


PI = math.pi


def log_min_exp(a, b, epsilon=1e-8):
    """
    Computes the log of exp(a) - exp(b) in a (more) numerically stable fashion,
    using:

        log(exp(a) - exp(b))
        = c + log(exp(a - c) - exp(b - c))
        = a + log(1 - exp(b - a))

    Note that we assume b < a always.
    """
    y = a + torch.log(1 - torch.exp(b - a) + epsilon)

    return y
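

def _demo_log_min_exp():
    # Hedged sketch, not part of the original file: check log_min_exp against
    # the naive computation on values where both are numerically stable.
    a = torch.tensor([2.0])
    b = torch.tensor([1.0])
    expected = torch.log(torch.exp(a) - torch.exp(b))
    assert torch.allclose(log_min_exp(a, b), expected, atol=1e-4)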


def log_normal(x, mean, logvar):
    logp = -0.5 * logvar
    logp += -0.5 * np.log(2 * PI)
    logp += -0.5 * (x - mean) * (x - mean) / torch.exp(logvar)
    return logp


def log_mixture_normal(x, mean, logvar, pi):
    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

    logp_mixtures = log_normal(x, mean, logvar)

    logp = torch.log(torch.sum(pi * torch.exp(logp_mixtures), dim=-1) + 1e-8)

    return logp


def sample_normal(mean, logvar):
    y = torch.randn_like(mean)

    x = torch.exp(0.5 * logvar) * y + mean

    return x


def sample_mixture_normal(mean, logvar, pi):
    b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
    pi = pi.view(b * c * h * w, n_mixtures)
    sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)

    # Select mixture params
    mean = mean.view(b * c * h * w, n_mixtures)
    mean = mean[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)
    logvar = logvar.view(b * c * h * w, n_mixtures)
    logvar = logvar[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)

    y = sample_normal(mean, logvar)

    return y


def log_logistic(x, mean, logscale):
    """
    pdf = sigma([x - mean] / scale) * [1 - sigma(...)] * 1/scale
    """
    scale = torch.exp(logscale)

    u = (x - mean) / scale

    logp = F.logsigmoid(u) + F.logsigmoid(-u) - logscale

    return logp


def sample_logistic(mean, logscale):
    y = torch.rand_like(mean)

    x = torch.exp(logscale) * torch.log(y / (1 - y)) + mean

    return x


def log_discretized_logistic(x, mean, logscale, inverse_bin_width):
    scale = torch.exp(logscale)

    logp = log_min_exp(
        F.logsigmoid((x + 0.5 / inverse_bin_width - mean) / scale),
        F.logsigmoid((x - 0.5 / inverse_bin_width - mean) / scale))

    return logp
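

def _demo_discretized_logistic_normalisation():
    # Hedged sketch, not part of the original file: the discretized logistic
    # pmf over adjacent bins of width 1/inverse_bin_width should sum to ~1
    # once the grid covers essentially all of the probability mass.
    inverse_bin_width = 256.
    x = torch.arange(-4., 4., 1. / inverse_bin_width)
    mean = torch.zeros_like(x)
    logscale = torch.full_like(x, -3.)
    log_p = log_discretized_logistic(x, mean, logscale, inverse_bin_width)
    assert abs(torch.exp(log_p).sum().item() - 1.) < 1e-2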


def discretized_logistic_cdf(x, mean, logscale, inverse_bin_width):
    scale = torch.exp(logscale)

    cdf = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)

    return cdf


def sample_discretized_logistic(mean, logscale, inverse_bin_width):
    x = sample_logistic(mean, logscale)

    x = torch.round(x * inverse_bin_width) / inverse_bin_width
    return x


def normal_cdf(value, loc, std):
    return 0.5 * (1 + torch.erf((value - loc) * std.reciprocal() / math.sqrt(2)))


def log_discretized_normal(x, mean, logvar, inverse_bin_width):
    std = torch.exp(0.5 * logvar)
    log_p = torch.log(
        normal_cdf(x + 0.5 / inverse_bin_width, mean, std)
        - normal_cdf(x - 0.5 / inverse_bin_width, mean, std) + 1e-7)

    return log_p


def log_mixture_discretized_normal(x, mean, logvar, pi, inverse_bin_width):
    std = torch.exp(0.5 * logvar)

    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

    p = normal_cdf(x + 0.5 / inverse_bin_width, mean, std) \
        - normal_cdf(x - 0.5 / inverse_bin_width, mean, std)

    p = torch.sum(p * pi, dim=-1)

    logp = torch.log(p + 1e-8)

    return logp


def sample_discretized_normal(mean, logvar, inverse_bin_width):
    y = torch.randn_like(mean)

    x = torch.exp(0.5 * logvar) * y + mean

    x = torch.round(x * inverse_bin_width) / inverse_bin_width

    return x


def log_mixture_discretized_logistic(x, mean, logscale, pi, inverse_bin_width):
    scale = torch.exp(logscale)

    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

    p = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale) \
        - torch.sigmoid((x - 0.5 / inverse_bin_width - mean) / scale)

    p = torch.sum(p * pi, dim=-1)

    logp = torch.log(p + 1e-8)

    return logp


def mixture_discretized_logistic_cdf(x, mean, logscale, pi, inverse_bin_width):
    scale = torch.exp(logscale)

    x = x[..., None]

    cdfs = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)

    cdf = torch.sum(cdfs * pi, dim=-1)

    return cdf


def sample_mixture_discretized_logistic(mean, logs, pi, inverse_bin_width):
    # Sample mixtures
    b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
    pi = pi.view(b * c * h * w, n_mixtures)
    sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)

    # Select mixture params
    mean = mean.view(b * c * h * w, n_mixtures)
    mean = mean[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)
    logs = logs.view(b * c * h * w, n_mixtures)
    logs = logs[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)

    y = torch.rand_like(mean)
    x = torch.exp(logs) * torch.log(y / (1 - y)) + mean

    x = torch.round(x * inverse_bin_width) / inverse_bin_width

    return x


def log_multinomial(logits, targets):
    return -F.cross_entropy(logits, targets, reduction='none')


def sample_multinomial(logits):
    b, n_categories, c, h, w = logits.size()
    logits = logits.permute(0, 2, 3, 4, 1)
    p = F.softmax(logits, dim=-1)
    p = p.view(b * c * h * w, n_categories)
    x = torch.multinomial(p, num_samples=1).view(b, c, h, w)
    return x
264
integer_discrete_flows/utils/load_data.py
Normal file
@@ -0,0 +1,264 @@
from __future__ import print_function

import numbers

import torch
import torch.utils.data as data_utils
import pickle
from scipy.io import loadmat

import numpy as np

import os
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from torchvision.transforms import functional as vf
from torch.utils.data import ConcatDataset

from PIL import Image

import os.path
from os.path import join
import sys
import tarfile


class ToTensorNoNorm():
    def __call__(self, X_i):
        return torch.from_numpy(np.array(X_i, copy=False)).permute(2, 0, 1)


class PadToMultiple(object):
    def __init__(self, multiple, fill=0, padding_mode='constant'):
        assert isinstance(multiple, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']

        self.multiple = multiple
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.
        Returns:
            PIL Image: Padded image.
        """
        w, h = img.size
        m = self.multiple
        nw = (w // m + int((w % m) != 0)) * m
        nh = (h // m + int((h % m) != 0)) * m
        padw = nw - w
        padh = nh - h

        out = vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode)
        return out

    def __repr__(self):
        return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\
            format(self.multiple, self.fill, self.padding_mode)
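

def _demo_pad_to_multiple():
    # Hedged sketch, not part of the original file: a 50x70 image padded to
    # the next multiples of 32 becomes 64x96; padding is applied on the
    # right and bottom only.
    img = Image.new('RGB', (70, 50))   # PIL size is (width, height)
    out = PadToMultiple(32)(img)
    assert out.size == (96, 64)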


class CustomTensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        X, y = self.tensors
        X_i, y_i = X[index], y[index]

        if self.transform:
            X_i = self.transform(X_i)
        X_i = torch.from_numpy(np.array(X_i, copy=False))
        X_i = X_i.permute(2, 0, 1)

        return X_i, y_i

    def __len__(self):
        return self.tensors[0].size(0)


def load_cifar10(args, **kwargs):
    # set args
    args.input_size = [3, 32, 32]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    from keras.datasets import cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    x_train = x_train.transpose(0, 3, 1, 2)
    x_test = x_test.transpose(0, 3, 1, 2)

    import math

    if args.data_augmentation_level == 2:
        data_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.CenterCrop(32)
        ])
    elif args.data_augmentation_level == 1:
        data_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
        ])
    else:
        data_transform = transforms.Compose([
            transforms.ToPILImage(),
        ])

    # Hold out the last 10000 training images for validation.
    x_val = x_train[-10000:]
    y_val = y_train[-10000:]

    x_train = x_train[:-10000]
    y_train = y_train[:-10000]

    train = CustomTensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train), transform=data_transform)
    train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)

    validation = data_utils.TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val))
    val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs)

    test = data_utils.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test))
    test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader, args


def extract_tar(tarpath):
    assert tarpath.endswith('.tar')

    startdir = tarpath[:-4] + '/'

    if os.path.exists(startdir):
        return startdir

    print('Extracting', tarpath)

    with tarfile.open(name=tarpath) as tar:
        t = 0
        done = False
        while not done:
            path = join(startdir, 'images{}'.format(t))
            os.makedirs(path, exist_ok=True)

            print(path)

            # Extract in chunks of 50000 files per subdirectory.
            for i in range(50000):
                member = tar.next()

                if member is None:
                    done = True
                    break

                # Skip directories
                while member.isdir():
                    member = tar.next()
                    if member is None:
                        done = True
                        break

                member.name = member.name.split('/')[-1]

                tar.extract(member, path=path)

            t += 1

    return startdir


def load_imagenet(resolution, args, **kwargs):
    assert resolution == 32 or resolution == 64

    args.input_size = [3, resolution, resolution]

    trainpath = '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution)
    valpath = '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution)

    trainpath = extract_tar(trainpath)
    valpath = extract_tar(valpath)

    data_transform = transforms.Compose([
        ToTensorNoNorm()
    ])

    print('Starting loading ImageNet')

    imagenet_data = torchvision.datasets.ImageFolder(
        trainpath,
        transform=data_transform)

    print('Number of data images', len(imagenet_data))

    val_idcs = np.random.choice(len(imagenet_data), size=20000, replace=False)
    train_idcs = np.setdiff1d(np.arange(len(imagenet_data)), val_idcs)

    train_dataset = torch.utils.data.dataset.Subset(
        imagenet_data, train_idcs)
    val_dataset = torch.utils.data.dataset.Subset(
        imagenet_data, val_idcs)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    test_dataset = torchvision.datasets.ImageFolder(
        valpath,
        transform=data_transform)

    print('Number of val images:', len(test_dataset))

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    return train_loader, val_loader, test_loader, args


def load_dataset(args, **kwargs):

    if args.dataset == 'cifar10':
        train_loader, val_loader, test_loader, args = load_cifar10(args, **kwargs)
    elif args.dataset == 'imagenet32':
        train_loader, val_loader, test_loader, args = load_imagenet(32, args, **kwargs)
    elif args.dataset == 'imagenet64':
        train_loader, val_loader, test_loader, args = load_imagenet(64, args, **kwargs)
    else:
        raise Exception('Wrong name of the dataset!')

    return train_loader, val_loader, test_loader, args


if __name__ == '__main__':
    class Args():
        def __init__(self):
            self.batch_size = 128

    train_loader, val_loader, test_loader, args = load_imagenet(32, Args())
56
integer_discrete_flows/utils/log_likelihood.py
Normal file
@@ -0,0 +1,56 @@
from __future__ import print_function
import numpy as np
from scipy.special import logsumexp
from optimization.loss import calculate_loss_array

# Note: this utility still follows a VAE-style model interface
# (x_mean, z_mu, z_var, ldj, z0, zk) and would need adapting to the
# IDF model and loss signatures before use.


def calculate_likelihood(X, model, args, S=5000, MB=500):

    # set auxiliary variables for number of training and test sets
    N_test = X.size(0)

    X = X.view(-1, *args.input_size)

    likelihood_test = []

    if S <= MB:
        R = 1
    else:
        R = S // MB
        S = MB

    for j in range(N_test):
        if j % 100 == 0:
            print('Progress: {:.2f}%'.format(j / (1. * N_test) * 100))

        x_single = X[j].unsqueeze(0)

        a = []
        for r in range(0, R):
            # Repeat it for all training points
            x = x_single.expand(S, *x_single.size()[1:]).contiguous()

            x_mean, z_mu, z_var, ldj, z0, zk = model(x)

            a_tmp = calculate_loss_array(x_mean, x, z_mu, z_var, z0, zk, ldj, args)

            a.append(-a_tmp.cpu().data.numpy())

        # calculate max
        a = np.asarray(a)
        a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
        likelihood_x = logsumexp(a)
        likelihood_test.append(likelihood_x - np.log(len(a)))

    likelihood_test = np.array(likelihood_test)

    nll = -np.mean(likelihood_test)

    if args.input_type == 'multinomial':
        bpd = nll / (np.prod(args.input_size) * np.log(2.))
    elif args.input_type == 'binary':
        bpd = 0.
    else:
        raise ValueError('invalid input type!')

    return nll, bpd
104
integer_discrete_flows/utils/plotting.py
Normal file
@@ -0,0 +1,104 @@
from __future__ import division
from __future__ import print_function

import numpy as np
import matplotlib
# noninteractive backend
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def plot_training_curve(train_loss, validation_loss, fname='training_curve.pdf', labels=None):
    """
    Plots train_loss and validation loss as a function of optimization iteration.
    :param train_loss: np.array of train_loss (1D or 2D)
    :param validation_loss: np.array of validation loss (1D or 2D)
    :param fname: output file name
    :param labels: if train_loss and validation loss are 2D, then labels indicate which variable is varied
        across training curves.
    :return: None
    """

    plt.close()

    matplotlib.rcParams.update({'font.size': 14})
    matplotlib.rcParams['mathtext.fontset'] = 'stix'
    matplotlib.rcParams['font.family'] = 'STIXGeneral'

    if len(train_loss.shape) == 1:
        # Single training curve
        fig, ax = plt.subplots(nrows=1, ncols=1)
        figsize = (6, 4)

        if train_loss.shape[0] == validation_loss.shape[0]:
            # validation score evaluated every iteration
            x = np.arange(train_loss.shape[0])
            ax.plot(x, train_loss, '-', lw=2., color='black', label='train')
            ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')

        elif train_loss.shape[0] % validation_loss.shape[0] == 0:
            # validation score evaluated every epoch
            x = np.arange(train_loss.shape[0])
            ax.plot(x, train_loss, '-', lw=2., color='black', label='train')

            x = np.arange(validation_loss.shape[0])
            x = (x + 1) * train_loss.shape[0] / validation_loss.shape[0]
            ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')
        else:
            raise ValueError('Length of train_loss and validation_loss must be equal or divisible')

        miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
        maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
        ax.set_ylim([miny, maxy])

    elif len(train_loss.shape) == 2:
        # Multiple training curves

        cmap = plt.cm.brg

        cNorm = matplotlib.colors.Normalize(vmin=0, vmax=train_loss.shape[0])
        scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)

        fig, ax = plt.subplots(nrows=1, ncols=1)
        figsize = (6, 4)

        if labels is None:
            labels = ['%d' % i for i in range(train_loss.shape[0])]

        if train_loss.shape[1] == validation_loss.shape[1]:
            for i in range(train_loss.shape[0]):
                color_val = scalarMap.to_rgba(i)

                # validation score evaluated every iteration; x indexes
                # iterations (axis 1), matching the curve length.
                x = np.arange(train_loss.shape[1])
                ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i])
                ax.plot(x, validation_loss[i], '--', lw=2., color=color_val)

        elif train_loss.shape[1] % validation_loss.shape[1] == 0:
            for i in range(train_loss.shape[0]):
                color_val = scalarMap.to_rgba(i)

                # validation score evaluated every epoch
                x = np.arange(train_loss.shape[1])
                ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i])

                x = np.arange(validation_loss.shape[1])
                x = (x + 1) * train_loss.shape[1] / validation_loss.shape[1]
                ax.plot(x, validation_loss[i], '-', lw=2., color=color_val)

        miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
        maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
        ax.set_ylim([miny, maxy])

    else:
        raise ValueError('train_loss and validation_loss must be 1D or 2D arrays')

    ax.set_xlabel('iteration')
    ax.set_ylabel('loss')
    plt.title('Training and validation loss')

    fig.set_size_inches(figsize)
    fig.subplots_adjust(hspace=0.1)
    plt.savefig(fname, bbox_inches='tight')

    plt.close()
37
integer_discrete_flows/utils/visual_evaluation.py
Normal file
@@ -0,0 +1,37 @@
from __future__ import print_function
import os
import numpy as np
import imageio


def plot_reconstructions(recon_mean, loss, loss_type, epoch, args):
    if epoch == 1:
        if not os.path.exists(args.snap_dir + 'reconstruction/'):
            os.makedirs(args.snap_dir + 'reconstruction/')
    if loss_type == 'bpd':
        fname = str(epoch) + '_bpd_%5.3f' % loss
    elif loss_type == 'elbo':
        fname = str(epoch) + '_elbo_%6.4f' % loss
    plot_images(args, recon_mean.data.cpu().numpy()[:100], args.snap_dir + 'reconstruction/', fname)


def plot_images(args, x_sample, dir, file_name, size_x=10, size_y=10):
    batch, channels, height, width = x_sample.shape

    print(x_sample.shape)

    mosaic = np.zeros((height * size_y, width * size_x, channels))

    for j in range(size_y):
        for i in range(size_x):
            idx = j * size_x + i

            image = x_sample[idx]

            # Index columns by width (the original indexed by height, which
            # only works for square images).
            mosaic[j * height:(j + 1) * height, i * width:(i + 1) * width] = \
                image.transpose(1, 2, 0)

    # Remove channel for BW images
    mosaic = mosaic.squeeze()

    imageio.imwrite(dir + file_name + '.png', mosaic)