feat: initial for IDF

This commit is contained in:
Robin Meersman 2025-11-07 12:54:36 +01:00
commit ef4684ef39
27 changed files with 2830 additions and 0 deletions

View file

@ -0,0 +1,19 @@
Copyright (c) 2019 Emiel Hoogeboom, Jorn Peters, Rianne van den Berg, Max Welling
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,29 @@
# Integer Discrete Flows and Lossless Compression
This repository contains the code for the experiments presented in [1].
## Usage
### CIFAR10 setup:
```
python main_experiment.py --n_flows 8 --n_levels 3 --n_channels 512 --coupling_type 'densenet' --densenet_depth 12 --n_mixtures 5 --splitprior_type densenet
```
### ImageNet32 setup:
```
python main_experiment.py --evaluate_interval_epochs 5 --n_flows 8 --n_levels 3 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet32' --epochs 100 --lr_decay 0.99
```
### ImageNet64 setup:
```
python main_experiment.py --evaluate_interval_epochs 1 --n_flows 8 --n_levels 4 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet64' --epochs 20 --lr_decay 0.99 --batch_size 64
```
# Acknowledgements
The Robert Bosch GmbH is acknowledged for financial support.
# References
[1] Hoogeboom, Emiel, Jorn WT Peters, Rianne van den Berg, and Max Welling. "Integer Discrete Flows and Lossless Compression." Conference on Neural Information Processing Systems (2019).

View file

@ -0,0 +1,132 @@
import numpy as np
from . import rans
from utils.distributions import discretized_logistic_cdf, \
mixture_discretized_logistic_cdf
import torch
precision = 24
n_bins = 4096
def cdf_fn(z, pz, variable_type, distribution_type, inverse_bin_width):
    """Evaluate the CDF of the configured distribution at ``z``.

    Only the discrete logistic case is implemented. A ``pz`` with two entries
    is forwarded to ``discretized_logistic_cdf``; a ``pz`` with three entries
    is forwarded to ``mixture_discretized_logistic_cdf``.

    Parameters
    ----------
    z:
        Locations at which the CDF is evaluated.
    pz:
        Sequence of distribution parameters, unpacked positionally into the
        CDF helper.
    variable_type: str
        ``'discrete'`` or ``'continuous'``.
    distribution_type: str
        ``'logistic'``, ``'normal'`` or ``'steplogistic'``.
    inverse_bin_width:
        Passed through to the CDF helper as a keyword argument.

    Raises
    ------
    ValueError
        For every combination other than discrete/logistic, or when ``pz``
        does not have two or three entries.
    """
    if variable_type == 'discrete' and distribution_type == 'logistic':
        if len(pz) == 2:
            return discretized_logistic_cdf(
                z, *pz, inverse_bin_width=inverse_bin_width)
        if len(pz) == 3:
            return mixture_discretized_logistic_cdf(
                z, *pz, inverse_bin_width=inverse_bin_width)
        raise ValueError(
            'expected 2 or 3 distribution parameters, got {}'.format(len(pz)))

    # Every other variable/distribution combination is not implemented.
    raise ValueError(
        'no CDF implemented for variable_type={!r}, '
        'distribution_type={!r}'.format(variable_type, distribution_type))
def CDF_fn(pz, bin_width, variable_type, distribution_type):
    """Build integer CDF tables, centred on the rounded mean, for rANS coding.

    Returns
    -------
    CDFs: numpy integer array whose last axis has ``n_bins`` entries.
    MEAN: long tensor holding the mean expressed in units of ``bin_width``.
    """
    if len(pz) == 2:
        mean = pz[0]
    else:
        # Otherwise take the middle entry along the last axis of pz[0].
        first = pz[0]
        mean = first[..., (first.size(-1) - 1) // 2]
    MEAN = torch.round(mean / bin_width).long()

    # n_bins consecutive bin indices centred on MEAN, converted to locations.
    offsets = torch.arange(-n_bins // 2, n_bins // 2)
    bin_locations = offsets[None, None, None, None, :] + MEAN.cpu()[..., None]
    bin_locations = bin_locations.float() * bin_width
    bin_locations = bin_locations.to(device=pz[0].device)

    # Give each parameter a trailing axis so it broadcasts against the bins.
    expanded = [param[:, :, :, :, None] for param in pz]
    cdf = cdf_fn(
        bin_locations - bin_width,
        expanded,
        variable_type,
        distribution_type,
        1. / bin_width)
    cdf = cdf.cpu().numpy()

    # Reweigh so that every bin gets at least 1 / (2^precision) probability:
    # CDF = floor[cdf * (2^precision - n_bins)] + range(n_bins)
    scale = (1 << precision) - n_bins
    CDFs = (cdf * scale).astype('int') + np.arange(n_bins)
    return CDFs, MEAN
def encode_sample(
        z, pz, variable_type, distribution_type, bin_width=1./256, state=None):
    """Push every element of ``z`` onto a rANS state.

    ``state`` is a flattened rANS state from a previous call, or ``None`` to
    start from ``rans.x_init``. Returns the new flattened state, or ``None``
    when a symbol falls outside the range covered by the CDF table.
    """
    state = rans.x_init if state is None else rans.unflatten(state)

    CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)

    # Shift z to Z so that it indexes into the CDF tables.
    Z = torch.round(z / bin_width).long() + n_bins // 2 - MEAN
    Z = Z.cpu().numpy()

    in_range = np.sum(Z < 0) == 0 and np.sum(Z >= n_bins - 1) == 0
    if not in_range:
        print('Z out of allowed range of values, canceling compression')
        return None

    symbols = Z.reshape(-1)
    tables = CDFs.reshape(-1, n_bins).copy()

    # Symbols are pushed in reverse order; decode_sample pops them forwards.
    for symbol, cdf in zip(symbols[::-1], tables[::-1]):
        push = rans.append_symbol(statfun_encode(cdf), precision)
        state = push(state, symbol)

    return rans.flatten(state)
def decode_sample(
        state, pz, variable_type, distribution_type, bin_width=1./256):
    """Pop one symbol per CDF table from a flattened rANS state.

    Returns ``(state, z)``: the remaining flattened state and the decoded
    values, shaped like the first four dimensions of ``pz[0]``.
    """
    state = rans.unflatten(state)

    device = pz[0].device
    size = pz[0].size()[0:4]

    CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)
    tables = CDFs.reshape(-1, n_bins)

    symbols = np.zeros(len(tables), dtype=int)
    for idx, cdf in enumerate(tables):
        pop_one = rans.pop_symbol(statfun_decode(cdf), precision)
        state, symbol = pop_one(state)
        symbols[idx] = symbol

    # Undo the index shift applied when encoding, then rescale to values.
    Z = torch.from_numpy(symbols).to(device).view(size) - n_bins // 2 + MEAN
    z = Z.float() * bin_width

    return rans.flatten(state), z
def statfun_encode(CDF):
    """Return a function mapping a symbol to its ``(start, freq)`` in CDF."""
    def _lookup(symbol):
        start = CDF[symbol]
        # The frequency is the distance to the next cumulative value.
        freq = CDF[symbol + 1] - CDF[symbol]
        return start, freq
    return _lookup
def statfun_decode(CDF):
    """Return a function mapping a cumulative value to its symbol.

    The returned function yields ``(s, (start, freq))`` for the given ``cf``.
    """
    lookup = statfun_encode(CDF)

    def _locate(cf):
        # Find s such that CDF[s] <= cf < CDF[s + 1].
        s = np.searchsorted(CDF, cf, side='right') - 1
        start, freq = lookup(s)
        return s, (start, freq)
    return _locate
def encode(x, symbol, cdf=None):
    """Push ``symbol`` onto the rANS state ``x`` using the table ``cdf``.

    The original two-argument form handed the *factory* ``statfun_encode`` to
    ``rans.append_symbol``. The factory returns a function rather than a
    ``(start, freq)`` pair, so every call failed with a TypeError while
    unpacking. ``cdf`` supplies the table the factory needs.

    Raises
    ------
    TypeError
        When ``cdf`` is omitted. This is the same exception type the
        two-argument form always raised.
    """
    if cdf is None:
        raise TypeError(
            'encode() needs a cdf table to look up (start, freq) for a symbol')
    return rans.append_symbol(statfun_encode(cdf), precision)(x, symbol)
def decode(x, cdf=None):
    """Pop one symbol from the rANS state ``x`` using the table ``cdf``.

    The original one-argument form handed the *factory* ``statfun_decode`` to
    ``rans.pop_symbol``. The factory returns a function rather than a
    ``(symbol, (start, freq))`` pair, so every call failed with a TypeError
    while unpacking. ``cdf`` supplies the table the factory needs.

    Returns whatever ``rans.pop_symbol(...)(x)`` returns.

    Raises
    ------
    TypeError
        When ``cdf`` is omitted. This is the same exception type the
        one-argument form always raised.
    """
    if cdf is None:
        raise TypeError(
            'decode() needs a cdf table to map a cumulative value to a symbol')
    return rans.pop_symbol(statfun_decode(cdf), precision)(x)

View file

@ -0,0 +1,67 @@
"""
Closely based on https://github.com/rygorous/ryg_rans/blob/master/rans64.h
ROUGH GUIDE:
We use the pythonic names 'append' and 'pop' for encoding and decoding
respectively. The compressed state 'x' is an immutable stack, implemented using
a cons list.
x: the current stack-like state of the encoder/decoder.
precision: the natural numbers are divided into ranges of size 2^precision.
start & freq: start indicates the beginning of the range in [0, 2^precision-1]
that the current symbol is represented by. freq is the length of the range.
freq is chosen such that p(symbol) ~= freq/2^precision.
"""
import numpy as np
from functools import reduce
rans_l = 1 << 31 # the lower bound of the normalisation interval
tail_bits = (1 << 32) - 1
x_init = (rans_l, ())
def append(x, start, freq, precision):
    """Encode a symbol occupying the range [start, start + freq).

    All frequencies are assumed to sum to ``1 << precision``. Returns the new
    state as a ``(head, tail)`` pair; ``x`` itself is not modified.
    """
    head, tail = x[0], x[1]

    # Renormalise first: move the low 32 bits of the head onto the stack.
    threshold = ((rans_l >> precision) << 32) * freq
    if head >= threshold:
        head, tail = head >> 32, (head & tail_bits, tail)

    encoded = ((head // freq) << precision) + (head % freq) + start
    return encoded, tail
def pop(x_, precision):
    """Begin popping a symbol from the state ``x_``.

    Returns ``(cf, advance)``. ``cf`` is the low ``precision`` bits of the
    head. ``advance(start, freq)`` returns the state left after removing a
    symbol with that range start and frequency.
    """
    mask = (1 << precision) - 1
    cf = x_[0] & mask

    def advance(start, freq):
        head = freq * (x_[0] >> precision) + cf - start
        tail = x_[1]
        if head < rans_l:
            # Renormalise: pull one 32-bit word back off the stack.
            return (head << 32) | tail[0], tail[1]
        return head, tail

    return cf, advance
def append_symbol(statfun, precision):
    """Return an encoder ``(x, symbol) -> new state``.

    ``statfun(symbol)`` must return the symbol's ``(start, freq)`` pair.
    """
    def append_(x, symbol):
        interval = statfun(symbol)
        start, freq = interval
        return append(x, start, freq, precision)

    return append_
def pop_symbol(statfun, precision):
    """Return a decoder ``x -> (new state, symbol)``.

    ``statfun(cf)`` must return ``(symbol, (start, freq))`` for the
    cumulative value ``cf``.
    """
    def pop_(x):
        cf, advance = pop(x, precision)
        symbol, interval = statfun(cf)
        start, freq = interval
        return advance(start, freq), symbol

    return pop_
def flatten(x):
    """Flatten a rans state x into a 1d numpy array of uint32 words.

    The output starts with the high and low 32-bit halves of the head,
    followed by the stack words from the top down.
    """
    head, tail = int(x[0]), x[1]
    # Mask the low half explicitly. Relying on the uint32 cast to truncate a
    # 64-bit Python int raises OverflowError on NumPy >= 2.0.
    out = [head >> 32, head & 0xFFFFFFFF]
    while tail:
        word, tail = tail
        out.append(word)
    return np.asarray(out, dtype=np.uint32)
def unflatten(arr):
    """Unflatten a 1d numpy array into a rans state ``(head, stack)``."""
    head = (int(arr[0]) << 32) | int(arr[1])
    # Rebuild the cons list from the last word so that arr[2] ends up on top.
    stack = ()
    for word in reversed(arr[2:]):
        stack = (int(word), stack)
    return head, stack

View file

@ -0,0 +1,188 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import argparse
import torch
import torch.utils.data
import numpy as np
from utils.load_data import load_dataset
# Command-line interface of the compression test script.
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

parser.add_argument('-d', '--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'imagenet32', 'imagenet64'],
                    metavar='DATASET',
                    help='Dataset choice.')

parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')

parser.add_argument('--manual_seed', type=int,
                    help='manual seed, if not given resorts to random seed.')

parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
                    help='how many batches to wait before logging training status')

parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
                    help='Evaluate per how many epochs')

# optimization settings
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
                    help='number of epochs to train (default: 2000)')
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
                    help='number of early stopping epochs')

# The help text used to claim a default of 100 while the real default is 10.
parser.add_argument('-bs', '--batch_size', type=int, default=10, metavar='BATCH_SIZE',
                    help='input batch size for training (default: 10)')
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
                    help='learning rate')
parser.add_argument('--warmup', type=int, default=10,
                    help='number of warmup epochs')

parser.add_argument('--data_augmentation_level', type=int, default=2,
                    help='data augmentation level')

parser.add_argument('--no_decode', action='store_true', default=False,
                    help='disables decoding')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Extra keyword arguments handed to load_dataset when CUDA is in use.
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
def encode_images(img, model, decode):
    """Encode a batch of square images, tiling them when larger than the model.

    Images whose side equals the model's input side go straight to
    ``encode_patches``. Larger images must be an exact multiple of that side;
    they are cut into tiles that are encoded one by one.

    Returns ``(states, state_sizes, bpd, error)``.
    """
    batchsize, img_c, img_h, img_w = img.size()
    c, h, w = model.args.input_size

    assert img_h == img_w and h == w

    if img_h == h:
        return encode_patches(img, model, decode)

    assert img_h % h == 0
    steps = img_h // h

    states = [[] for _ in range(batchsize)]
    state_sizes = [0] * batchsize
    bpd = [0] * batchsize
    error = 0

    for j in range(steps):
        rows = slice(j * h, (j + 1) * h)
        for i in range(steps):
            cols = slice(i * w, (i + 1) * w)
            patch_states, patch_sizes, patch_bpd, patch_error = \
                encode_patches(img[:, :, rows, cols], model, decode)

            for b in range(batchsize):
                if patch_states[b] is None:
                    states[b].append(None)
                else:
                    states[b].extend(patch_states[b])
                state_sizes[b] += patch_sizes[b]
                # Each tile contributes equally to the per-image bpd.
                bpd[b] += patch_bpd[b] / steps**2

            error += patch_error

    return states, state_sizes, bpd, error
def encode_patches(imgs, model, decode):
    """Encode a batch of patches that exactly match the model's input size.

    Returns ``(states, state_sizes, bpd, error)``. ``state_sizes`` holds the
    size in bits per patch, including one escape bit. ``error`` is the summed
    absolute reconstruction difference, accumulated only when ``decode`` is
    true.
    """
    batchsize, img_c, img_h, img_w = imgs.size()
    c, h, w = model.args.input_size
    assert img_h == h and img_w == w

    states = model.encode(imgs)
    bpd = model.forward(imgs)[1].cpu().numpy()

    raw_bits = 8 * img_c * img_h * img_w
    state_sizes = []
    error = 0

    for b in range(batchsize):
        state = states[b]

        if state is None:
            # Using escape bit ;) -- the patch is costed as 8 bits per value.
            state_sizes.append(raw_bits + 1)
            # Error remains unchanged.
            print('Escaping, not encoding.')
            continue

        if decode:
            x_recon = model.decode([state])
            difference = torch.abs(x_recon.int() - imgs[b].int())
            error += torch.sum(difference).item()

        # Append state plus an escape bit
        state_sizes.append(32 * len(state) + 1)

    return states, state_sizes, bpd, error
def run(args, kwargs):
    """Compress the test set with a stored model and report the results.

    Loads ``a.model`` from the hard-coded snapshot directory, encodes every
    test batch with ``encode_images`` and prints the running compression
    rate, the reconstruction error and the analytical bpd.

    Parameters
    ----------
    args:
        Parsed command-line namespace; ``args.cuda`` and ``args.no_decode``
        are read here and ``args.snap_dir`` is set.
    kwargs: dict
        Extra keyword arguments forwarded to ``load_dataset``.
    """
    import time

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args.snap_dir = snap_dir = \
        'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # NOTE: torch.load unpickles the file, so only load trusted snapshots.
    # Without CUDA the storages are mapped to the CPU so a snapshot written
    # on a GPU machine can still be read.
    final_model = torch.load(
        snap_dir + 'a.model', map_location=None if args.cuda else 'cpu')
    if hasattr(final_model, 'module'):
        final_model = final_model.module

    # The model used to be moved to the GPU unconditionally, which crashed
    # with --no_cuda or on machines without CUDA.
    final_model = final_model.cuda() if args.cuda else final_model.cpu()

    sizes = []
    errors = []
    bpds = []

    start = time.time()
    t = 0
    with torch.no_grad():
        for data, _ in test_loader:
            if args.cuda:
                data = data.cuda()

            state, state_sizes, bpd, error = \
                encode_images(data, final_model, decode=not args.no_decode)

            errors += [error]
            bpds.extend(bpd)
            sizes.extend(state_sizes)

            t += len(data)

            print(
                'Examples: {}/{} bpd compression: {:.3f} error: {},'
                ' analytical bpd {:.3f}'.format(
                    t, len(test_loader.dataset),
                    np.mean(sizes) / np.prod(data.size()[1:]),
                    np.sum(errors),
                    np.mean(bpds)
                ))

    if t == 0:
        # Nothing was encoded; the summary below would divide by zero.
        print('Test loader is empty, nothing to report.')
        return

    if args.no_decode:
        print('Not testing decoding.')
    else:
        print('Error: {}'.format(np.sum(errors)))

    print('Took {:.3f} seconds / example'.format((time.time() - start) / t))
    print('Final bpd: {:.3f} error: {}'.format(
        np.mean(sizes) / np.prod(data.size()[1:]),
        np.sum(errors)))


if __name__ == "__main__":
    run(args, kwargs)

View file

@ -0,0 +1,105 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import argparse
import time
import torch
import torch.utils.data
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import os
from optimization.training import train, evaluate
from utils.load_data import load_dataset
from utils.plotting import plot_training_curve
import imageio
# Command-line interface of the partial-reconstruction script.
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

parser.add_argument('-d', '--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'imagenet32', 'imagenet64'],
                    metavar='DATASET',
                    help='Dataset choice.')

# The help text used to claim a default of 100 while the real default is 256.
parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE',
                    help='input batch size for training (default: 256)')

parser.add_argument('--data_augmentation_level', type=int, default=2,
                    help='data augmentation level')

parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Extra keyword arguments handed to load_dataset when CUDA is in use.
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
def run(args, kwargs):
    """Write image grids of partial reconstructions for a few test images.

    For each of up to eleven images of the first test batch, the image is
    passed through the model and then inverted with ``z`` and progressively
    more of the ``ys`` the forward pass returned (the last ``j`` of them for
    ``j = 0 .. len(ys)``). One PNG grid per ``j`` is written to
    ``<snap_dir>/partials/``.

    Parameters
    ----------
    args:
        Parsed command-line namespace; ``args.cuda`` is read here and
        ``args.snap_dir`` is set.
    kwargs: dict
        Extra keyword arguments forwarded to ``load_dataset``.
    """
    args.snap_dir = snap_dir = \
        'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # NOTE: torch.load unpickles the file, so only load trusted snapshots.
    final_model = torch.load(snap_dir + 'a.model')
    if hasattr(final_model, 'module'):
        final_model = final_model.module

    from models.backround import SmoothRound
    for module in final_model.modules():
        if isinstance(module, SmoothRound):
            # NOTE(review): nothing visible in this file reads _round_decay;
            # confirm it is still used by SmoothRound.
            module._round_decay = 1.

    exp_dir = snap_dir + 'partials/'
    os.makedirs(exp_dir, exist_ok=True)

    images = []
    n_partials = 0
    with torch.no_grad():
        # Only the first batch is used.
        for data, _ in test_loader:
            if args.cuda:
                data = data.cuda()

            for i in range(len(data)):
                _, _, _, pz, z, pys, ys, ldj = final_model.forward(data[i:i+1])

                n_partials = len(ys) + 1
                for j in range(n_partials):
                    x_recon = final_model.inverse(
                        z,
                        ys[len(ys) - j:])
                    images.append(x_recon.float())

                if i == 10:
                    break
            break

    if not images:
        # An empty loader used to fail with a NameError on `ys` below.
        print('Test loader is empty, no reconstructions written.')
        return

    for j in range(n_partials):
        # images is ordered image-major, so every n_partials-th entry from j
        # is the j-th partial reconstruction of successive images.
        # `range=None` was dropped: it was the default, and newer torchvision
        # releases no longer accept that keyword.
        grid = make_grid(
            torch.stack(images[j::n_partials], dim=0).squeeze(),
            nrow=11, padding=0,
            normalize=True,
            scale_each=False, pad_value=0)

        imageio.imwrite(
            exp_dir + 'loaded{j}.png'.format(j=j),
            grid.cpu().numpy().transpose(1, 2, 0))


if __name__ == "__main__":
    run(args, kwargs)

View file

@ -0,0 +1,285 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import argparse
import time
import torch
import torch.utils.data
import torch.optim as optim
import numpy as np
import math
import random
import os
import datetime
from optimization.training import train, evaluate
from utils.load_data import load_dataset
# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

parser.add_argument('-d', '--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'imagenet32', 'imagenet64'],
                    metavar='DATASET',
                    help='Dataset choice.')

parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')

parser.add_argument('--manual_seed', type=int,
                    help='manual seed, if not given resorts to random seed.')

parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
                    help='how many batches to wait before logging training status')

parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
                    help='Evaluate per how many epochs')

parser.add_argument('-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR',
                    help='output directory for model snapshots etc.')

fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('-te', '--testing', action='store_true', dest='testing',
                help='evaluate on test set after training')
fp.add_argument('-va', '--validation', action='store_false', dest='testing',
                help='only evaluate on validation set')
parser.set_defaults(testing=True)

# optimization settings
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
                    help='number of epochs to train (default: 2000)')
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
                    help='number of early stopping epochs')

# The help text used to claim a default of 100 while the real default is 256.
parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE',
                    help='input batch size for training (default: 256)')
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
                    help='learning rate')
parser.add_argument('--warmup', type=int, default=10,
                    help='number of warmup epochs')

parser.add_argument('--data_augmentation_level', type=int, default=2,
                    help='data augmentation level')

parser.add_argument('--variable_type', type=str, default='discrete',
                    help='variable type of data distribution: discrete/continuous',
                    choices=['discrete', 'continuous'])
parser.add_argument('--distribution_type', type=str, default='logistic',
                    choices=['logistic', 'normal', 'steplogistic'],
                    help='distribution type: logistic/normal/steplogistic')
parser.add_argument('--n_flows', type=int, default=8,
                    help='number of flows per level')
parser.add_argument('--n_levels', type=int, default=3,
                    help='number of levels')

parser.add_argument('--n_bits', type=int, default=8,
                    help='number of bits of the input data '
                         '(the data domain is 2**n_bits)')

# ---------------- SETTINGS CONCERNING NETWORKS -------------
parser.add_argument('--densenet_depth', type=int, default=8,
                    help='Depth of densenets')
parser.add_argument('--n_channels', type=int, default=512,
                    help='number of channels in coupling and splitprior')
# ---------------- ----------------------------- -------------

# ---------------- SETTINGS CONCERNING COUPLING LAYERS -------------
parser.add_argument('--coupling_type', type=str, default='shallow',
                    choices=['shallow', 'resnet', 'densenet'],
                    help='Type of coupling layer')
parser.add_argument('--splitfactor', default=0, type=int,
                    help='Split factor for coupling layers.')

parser.add_argument('--split_quarter', dest='split_quarter', action='store_true',
                    help='Split coupling layer on quarter')
parser.add_argument('--no_split_quarter', dest='split_quarter', action='store_false')
parser.set_defaults(split_quarter=True)
# ---------------- ----------------------------------- -------------

# ---------------- SETTINGS CONCERNING SPLITPRIORS -------------
parser.add_argument('--splitprior_type', type=str, default='shallow',
                    choices=['none', 'shallow', 'resnet', 'densenet'],
                    help='Type of splitprior. Use \'none\' for no splitprior')
# ---------------- ------------------------------- -------------

# ---------------- SETTINGS CONCERNING PRIORS -------------
parser.add_argument('--n_mixtures', type=int, default=1,
                    help='number of mixtures')
# ---------------- ------------------------------- -------------

parser.add_argument('--hard_round', dest='hard_round', action='store_true',
                    help='Rounding of translation in discrete models. Weird '
                         'probabilistic implications, only for experimental phase')
parser.add_argument('--no_hard_round', dest='hard_round', action='store_false')
parser.set_defaults(hard_round=True)

parser.add_argument('--round_approx', type=str, default='smooth',
                    choices=['smooth', 'stochastic'])

# The learning rate is multiplied by lr_decay ** epoch in the scheduler.
parser.add_argument('--lr_decay', default=0.999, type=float,
                    help='Multiplicative learning rate decay per epoch')

parser.add_argument('--temperature', default=1.0, type=float,
                    help='Temperature used for BackRound. It is used in '
                         'the SmoothRound module. '
                         '(default=1.0)')

# gpu/cpu
parser.add_argument('--gpu_num', type=int, default=0, metavar='GPU',
                    help='choose GPU to run on.')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Fall back to a random seed, then seed every random number generator in use.
if args.manual_seed is None:
    args.manual_seed = random.randint(1, 100000)
random.seed(args.manual_seed)
torch.manual_seed(args.manual_seed)
np.random.seed(args.manual_seed)

# Extra keyword arguments handed to load_dataset when CUDA is in use.
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
def run(args, kwargs):
    """Train a model, snapshot the best one and evaluate it.

    A snapshot directory is created under ``args.out_dir``, the settings are
    logged there, the model is trained for ``args.epochs`` epochs and the
    model with the lowest training bpd at an evaluation epoch is saved as
    ``a.model``. Unless ``--validation`` was given, that snapshot is finally
    evaluated on the test set.

    Parameters
    ----------
    args:
        Parsed command-line namespace; several attributes (``snap_dir``,
        ``model_signature``, ``device``) are added to it here.
    kwargs: dict
        Extra keyword arguments forwarded to ``load_dataset``.

    Raises
    ------
    ValueError
        When the validation loss becomes NaN.
    """
    print('\nMODEL SETTINGS: \n', args, '\n')
    print("Random Seed: ", args.manual_seed)

    if 'imagenet' in args.dataset and args.evaluate_interval_epochs > 5:
        args.evaluate_interval_epochs = 5

    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================
    args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
    args.model_signature = args.model_signature.replace(':', '_')

    snapshots_path = os.path.join(
        args.out_dir,
        args.variable_type + '_' + args.distribution_type + args.dataset)
    snap_dir = snapshots_path

    snap_dir += '_' + 'flows_' + str(args.n_flows) + '_levels_' + str(args.n_levels)

    snap_dir = snap_dir + '__' + args.model_signature + '/'

    args.snap_dir = snap_dir

    os.makedirs(snap_dir, exist_ok=True)

    with open(snap_dir + 'log.txt', 'a') as ff:
        print('\nMODEL SETTINGS: \n', args, '\n', file=ff)

    # SAVING
    torch.save(args, snap_dir + '.config')

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # ==================================================================================================================
    # SELECT MODEL
    # ==================================================================================================================
    # flow parameters and architecture choice are passed on to model through args
    print(args.input_size)

    import models.Model as Model

    model = Model.Model(args)
    # Honour --no_cuda: the device used to be chosen from CUDA availability
    # alone, ignoring args.cuda.
    args.device = torch.device("cuda:0" if args.cuda else "cpu")
    model.set_temperature(args.temperature)
    model.enable_hard_round(args.hard_round)

    model_sample = model

    # ====================================
    # INIT
    # ====================================
    # data dependent initialization on CPU
    for batch_idx, (data, _) in enumerate(train_loader):
        model(data)
        break

    if args.cuda and torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model, dim=0)

    model.to(args.device)

    def lr_lambda(epoch):
        return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)

    optimizer = optim.Adamax(model.parameters(), lr=args.learning_rate, eps=1.e-7)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

    # ==================================================================================================================
    # TRAINING
    # ==================================================================================================================
    train_bpd = []
    val_bpd = []

    # for early stopping
    best_val_bpd = np.inf
    best_train_bpd = np.inf
    epoch = 0

    train_times = []

    model.eval()
    model.train()

    for epoch in range(1, args.epochs + 1):
        t_start = time.time()
        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # steps. The order is kept so the learning-rate schedule stays
        # unchanged.
        scheduler.step()
        tr_loss, tr_bpd = train(epoch, train_loader, model, optimizer, args)
        train_bpd.append(tr_bpd)
        train_times.append(time.time()-t_start)
        print('One training epoch took %.2f seconds' % (time.time()-t_start))

        if epoch < 25 or epoch % args.evaluate_interval_epochs == 0:
            v_loss, v_bpd = evaluate(
                train_loader, val_loader, model, model_sample, args,
                epoch=epoch, file=snap_dir + 'log.txt')

            val_bpd.append(v_bpd)

            # Model save based on TRAIN performance (is heavily correlated with validation performance.)
            if np.mean(tr_bpd) < best_train_bpd:
                best_train_bpd = np.mean(tr_bpd)
                best_val_bpd = v_bpd
                # The model is only wrapped in DataParallel on multi-GPU
                # machines; `model.module` used to fail everywhere else.
                unwrapped = model.module \
                    if isinstance(model, torch.nn.DataParallel) else model
                torch.save(unwrapped, snap_dir + 'a.model')
                torch.save(optimizer, snap_dir + 'a.optimizer')
                print('->model saved<-')

            print('(BEST: train bpd {:.4f}, test bpd {:.4f})\n'.format(
                best_train_bpd, best_val_bpd))

            if math.isnan(v_loss):
                raise ValueError('NaN encountered!')

    train_bpd = np.hstack(train_bpd)
    val_bpd = np.array(val_bpd)

    # training time per epoch
    train_times = np.array(train_times)
    mean_train_time = np.mean(train_times)
    std_train_time = np.std(train_times, ddof=1)
    print('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time))

    # ==================================================================================================================
    # EVALUATION
    # ==================================================================================================================
    if not getattr(args, 'testing', True):
        # --validation asks for validation-only evaluation, which already
        # happened during training.
        return

    # NOTE: torch.load unpickles the file, so only load trusted snapshots.
    final_model = torch.load(snap_dir + 'a.model')
    test_loss, test_bpd = evaluate(
        train_loader, test_loader, final_model, final_model, args,
        epoch=epoch, file=snap_dir + 'test_log.txt')

    print('Test loss / bpd: %.2f / %.2f' % (test_loss, test_bpd))


if __name__ == "__main__":
    run(args, kwargs)

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,191 @@
import torch
import models.generative_flows as generative_flows
import numpy as np
from models.utils import Base
from .priors import Prior
from optimization.loss import compute_loss_array
from coding.coder import encode_sample, decode_sample
class Normalize(Base):
    """Map inputs from ``[0, 2**n_bits)`` to a centred unit-width range."""

    def __init__(self, args):
        super().__init__()
        self.n_bits = args.n_bits
        self.variable_type = args.variable_type
        self.input_size = args.input_size

    def forward(self, x, ldj, reverse=False):
        """Normalise ``x`` (or undo it when ``reverse``) and update ``ldj``.

        Returns ``(x, ldj)``. Raises ValueError for an unknown variable type.
        """
        domain = 2.**self.n_bits

        if self.variable_type not in ('discrete', 'continuous'):
            raise ValueError

        if self.variable_type == 'discrete':
            # Discrete variables will be measured on intervals sized 1/domain.
            # Hence, there is no need to change the log Jacobian determinant.
            dldj = 0
        else:
            dldj = -np.log(domain) * np.prod(self.input_size)

        if reverse:
            x = x * domain + domain / 2
            ldj -= dldj
        else:
            x = (x - domain / 2) / domain
            ldj += dldj

        return x, ldj
class Model(Base):
    """
    Flow model made of an input normalisation, a generative flow and a prior.

    Besides the forward pass (which also computes the loss), the class offers
    inversion, sampling and entropy coding of batches through
    ``encode_sample`` / ``decode_sample``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type

        n_channels, height, width = args.input_size

        self.normalize = Normalize(args)

        self.flow = generative_flows.GenerativeFlow(
            n_channels, height, width, args)

        self.n_bits = args.n_bits

        self.z_size = self.flow.z_size

        self.prior = Prior(self.z_size, args)

    def dequantize(self, x):
        """Add uniform noise to ``x``; the noise is shrunk when not training."""
        noise = torch.rand_like(x)
        if self.training:
            return x + noise

        # Required for stability.
        alpha = 1e-3
        return x + alpha + noise * (1 - 2 * alpha)

    def loss(self, pz, z, pys, ys, ldj):
        """Return ``(loss, bpd, bpd_per_prior)`` including auxiliary losses."""
        batchsize = z.size(0)
        loss, bpd, bpd_per_prior = \
            compute_loss_array(pz, z, pys, ys, ldj, self.args)

        # Modules may contribute an extra loss term, averaged over the batch.
        for module in self.modules():
            if hasattr(module, 'auxillary_loss'):
                loss += module.auxillary_loss() / batchsize

        return loss, bpd, bpd_per_prior

    def forward(self, x):
        """
        Run a uint8 batch through normalisation, flow and prior.

        Returns ``(loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj)``.
        """
        assert x.dtype == torch.uint8

        x = x.float()

        ldj = torch.zeros_like(x[:, 0, 0, 0])
        if self.variable_type == 'continuous':
            x = self.dequantize(x)
        elif self.variable_type != 'discrete':
            raise ValueError

        x, ldj = self.normalize(x, ldj)

        z, ldj, pys, ys = self.flow(x, ldj, pys=(), ys=())

        pz, z, ldj = self.prior(z, ldj)

        loss, bpd, bpd_per_prior = self.loss(pz, z, pys, ys, ldj)

        return loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj

    def inverse(self, z, ys):
        """Run the flow in reverse from ``z`` and ``ys``; returns uint8."""
        ldj = torch.zeros_like(z[:, 0, 0, 0])
        x, ldj, pys, py = \
            self.flow(z, ldj, pys=[], ys=ys, reverse=True)

        x, ldj = self.normalize(x, ldj, reverse=True)

        clamped = torch.clamp(x, min=0, max=255)
        return clamped.to(torch.uint8)

    def sample(self, n):
        """Draw ``n`` samples from the prior and map them back to uint8."""
        z_sample = self.prior.sample(n)

        ldj = torch.zeros_like(z_sample[:, 0, 0, 0])
        x_sample, ldj, pys, py = \
            self.flow(z_sample, ldj, pys=[], ys=[], reverse=True)

        x_sample, ldj = self.normalize(x_sample, ldj, reverse=True)

        clamped = torch.clamp(x_sample, min=0, max=255)
        return clamped.to(torch.uint8)

    def encode(self, x):
        """Entropy-code every example of the batch ``x``.

        Returns one state per example; an entry is ``None`` when
        ``encode_sample`` gave up on that example.
        """
        batchsize = x.size(0)
        _, _, _, pz, z, pys, ys, _ = self.forward(x)

        # Code the ys first and z last, each with its own parameters.
        pjs = list(pys) + [pz]
        js = list(ys) + [z]

        states = []

        for b in range(batchsize):
            state = None
            for pj, j in zip(pjs, js):
                pj_b = [param[b:b+1] for param in pj]
                j_b = j[b:b+1]

                state = encode_sample(
                    j_b, pj_b, self.variable_type,
                    self.distribution_type, state=state)
                if state is None:
                    break

            states.append(state)

        return states

    def decode(self, states):
        """Decode a list of states back into a uint8 batch."""
        def decode_fn(states, pj):
            # Decode one variable for every example, consuming from its state.
            states = list(states)
            decoded = []

            for b in range(len(states)):
                pj_b = [param[b:b+1] for param in pj]

                states[b], j_b = decode_sample(
                    states[b], pj_b, self.variable_type,
                    self.distribution_type)

                decoded.append(j_b)

            return states, torch.cat(decoded, dim=0)

        states, z = self.prior.decode(states, decode_fn=decode_fn)

        ldj = torch.zeros_like(z[:, 0, 0, 0])

        x, ldj = self.flow.decode(z, ldj, states, decode_fn=decode_fn)

        x, ldj = self.normalize(x, ldj, reverse=True)
        return x.to(dtype=torch.uint8)

View file

@ -0,0 +1,151 @@
import torch
import torch.nn.functional as F
import numpy as np
from models.utils import Base
class RoundStraightThrough(torch.autograd.Function):
def __init__(self):
super().__init__()
@staticmethod
def forward(ctx, input):
rounded = torch.round(input, out=None)
return rounded
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
return grad_input
_round_straightthrough = RoundStraightThrough().apply
def _stacked_sigmoid(x, temperature, n_approx=3):
x_ = x - 0.5
rounded = torch.round(x_)
x_remainder = x_ - rounded
size = x_.size()
x_remainder = x_remainder.view(size + (1,))
translation = torch.arange(n_approx) - n_approx // 2
translation = translation.to(device=x.device, dtype=x.dtype)
translation = translation.view([1] * len(size) + [len(translation)])
out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)
return out + rounded - (n_approx // 2)
class SmoothRound(Base):
    """Rounding layer with a stacked-sigmoid approximation and optional hard
    rounding.

    ``temperature`` and ``hard_round`` must be set before ``forward`` is
    called.
    """

    def __init__(self):
        # These two are assigned before the base initialiser runs, as in the
        # original ordering.
        self._temperature = None
        self._n_approx = None
        super().__init__()
        self.hard_round = None

    @property
    def temperature(self):
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        self._temperature = value

        # Lower temperatures get by with fewer stacked sigmoids.
        if value <= 0.05:
            n_approx = 1
        elif 0.05 < value < 0.13:
            n_approx = 3
        else:
            n_approx = 5
        self._n_approx = n_approx

    def forward(self, x):
        assert self._temperature is not None
        assert self._n_approx is not None
        assert self.hard_round is not None

        # Above a temperature of 0.25 the input is passed through unchanged.
        h = x
        if self.temperature <= 0.25:
            h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx)

        if self.hard_round:
            h = _round_straightthrough(h)

        return h
class StochasticRound(Base):
    """Rounding layer that shifts the input by uniform noise in [-0.5, 0.5)
    before the optional hard rounding."""

    def __init__(self):
        super().__init__()
        self.hard_round = None

    def forward(self, x):
        noise = torch.rand_like(x)
        h = x + noise - 0.5

        if not self.hard_round:
            return h
        return _round_straightthrough(h)
class BackRound(Base):
    def __init__(self, args, inverse_bin_width):
        """
        BackRound is an approximation to Round that allows for Backpropagation.

        The input is multiplied by ``inverse_bin_width``, passed through the
        selected rounding module and divided by ``inverse_bin_width`` again.

        With ``args.round_approx == 'smooth'`` a ``SmoothRound`` is used,
        which approximates rounding by a sum of translated sigmoids:

        * temperature <= 0.25 and hard_round = True: the forward output is
          rounded, the gradient is that of the sum of sigmoids.
        * temperature > 0.25 and hard_round = True: a round function with a
          straight-through gradient estimator.
        * hard_round = False: the output is not constrained to integers;
          for temperature > 0.25 it is the identity.

        With ``args.round_approx == 'stochastic'`` a ``StochasticRound`` is
        used instead.

        Arguments
        ---------
        args:
            Settings object; only ``args.round_approx`` is read here.
        inverse_bin_width:
            Scale applied before, and removed after, rounding.

        Raises
        ------
        ValueError
            If ``args.round_approx`` is neither 'smooth' nor 'stochastic'.
        """
        super().__init__()
        self.inverse_bin_width = inverse_bin_width
        self.round_approx = args.round_approx

        rounders = {'smooth': SmoothRound, 'stochastic': StochasticRound}
        rounder_cls = rounders.get(args.round_approx)
        if rounder_cls is None:
            raise ValueError
        self.round = rounder_cls()

    def forward(self, x):
        if self.round_approx not in ('smooth', 'stochastic'):
            raise ValueError
        scaled = x * self.inverse_bin_width
        rounded = self.round(scaled)
        return rounded / self.inverse_bin_width

View file

@ -0,0 +1,142 @@
"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import numpy as np
from models.utils import Base
from .backround import BackRound
from .networks import NN
UNIT_TESTING = False
class SplitFactorCoupling(Base):
    """Additive coupling layer over a channel split.

    The first ``split_idx`` channels are fed to a network whose output is
    added to (or, in reverse, subtracted from) the remaining channels. For
    discrete variables the network output is rounded first.
    """

    def __init__(self, c_in, factor, height, width, args):
        super().__init__()
        self.n_channels = args.n_channels
        self.kernel = 3
        self.input_channel = c_in
        self.round_approx = args.round_approx

        self.round = (
            BackRound(args, inverse_bin_width=2**args.n_bits)
            if args.variable_type == 'discrete' else None)

        # c_in // factor channels are transformed, the rest condition them.
        n_transformed = c_in // factor
        self.split_idx = c_in - n_transformed

        self.nn = NN(
            args=args,
            c_in=self.split_idx,
            c_out=c_in - self.split_idx,
            height=height,
            width=width,
            kernel=self.kernel,
            nn_type=args.coupling_type)

    def forward(self, z, ldj, reverse=False):
        z1 = z[:, :self.split_idx, :, :]
        z2 = z[:, self.split_idx:, :, :]

        t = self.nn(z1)
        if self.round is not None:
            t = self.round(t)

        z2 = z2 - t if reverse else z2 + t
        return torch.cat([z1, z2], dim=1), ldj
class Coupling(Base):
    """Wrapper that picks the split factor and builds a SplitFactorCoupling."""

    def __init__(self, c_in, height, width, args):
        super().__init__()

        # split_quarter takes precedence over an explicit splitfactor.
        factor = 2
        if args.split_quarter:
            factor = 4
        elif args.splitfactor > 1:
            factor = args.splitfactor

        self.coupling = SplitFactorCoupling(
            c_in, factor, height, width, args=args)

    def forward(self, z, ldj, reverse=False):
        return self.coupling(z, ldj, reverse)
def test_generative_flow():
    """Smoke test: run a Coupling layer forward, backward and in reverse.

    Prints the model and the reconstruction error between the input and the
    result of applying the layer followed by its reverse.
    """
    import models.networks as networks
    global UNIT_TESTING

    # Flag both this module and the networks module as under test.
    networks.UNIT_TESTING = True
    UNIT_TESTING = True

    batch_size = 17
    input_size = [12, 16, 16]

    class Args():
        def __init__(self):
            vars(self).update(
                input_size=input_size,
                learn_split=False,
                variable_type='continuous',
                distribution_type='logistic',
                round_approx='smooth',
                coupling_type='shallow',
                conv_type='standard',
                densenet_depth=8,
                bottleneck=False,
                n_channels=512,
                network1x1='standard',
                auxilary_freq=-1,
                actnorm=False,
                LU=False,
                coupling_lifting_L=True,
                splitprior=True,
                split_quarter=True,
                n_levels=2,
                n_flows=2,
                cond_L=True,
                n_bits=True,
            )

    args = Args()

    x = (torch.randint(256, size=[batch_size] + input_size).float() - 128.) / 256.
    ldj = torch.zeros_like(x[:, 0, 0, 0])

    model = Coupling(c_in=12, height=16, width=16, args=args)
    print(model)

    model.set_temperature(1.)
    model.enable_hard_round()
    model.eval()

    z, ldj = model(x, ldj, reverse=False)

    # Check if gradient computation works
    loss = torch.sum(z**2)
    loss.backward()

    recon, ldj = model(z, ldj, reverse=True)

    residual = x - recon
    sse = torch.sum(torch.pow(residual, 2)).item()
    ae = torch.abs(residual).sum()
    print('Error in recon: sse {} ae {}'.format(sse / np.prod(input_size), ae))


if __name__ == '__main__':
    test_generative_flow()

View file

@ -0,0 +1,175 @@
"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import numpy as np
from models.utils import Base
from .priors import SplitPrior
from .coupling import Coupling
UNIT_TESTING = False
def space_to_depth(x):
    """Move each 2x2 spatial patch of a (B, C, H, W) tensor into channels.

    Returns a tensor of size (B, 4 * C, H // 2, W // 2).
    """
    xs = x.size()
    batch, channels = xs[0], xs[1]
    half_h, half_w = xs[2] // 2, xs[3] // 2
    # Expose the within-patch row/column offsets as their own axes.
    patches = x.view(batch, channels, half_h, 2, half_w, 2)
    # Bring both offset axes next to the channel axis.
    patches = patches.permute((0, 1, 3, 5, 2, 4)).contiguous()
    # Fold the offsets into the channel axis.
    return patches.view(batch, channels * 4, half_h, half_w)
def depth_to_space(x):
    """Spread groups of four channels of a (B, C, H, W) tensor into 2x2 patches.

    Returns a tensor of size (B, C // 4, 2 * H, 2 * W).
    """
    xs = x.size()
    batch, out_channels = xs[0], xs[1] // 4
    height, width = xs[2], xs[3]
    # Split the channel axis into (channel, row offset, column offset).
    patches = x.view(batch, out_channels, 2, 2, height, width)
    # Put each offset axis right after its spatial axis.
    patches = patches.permute((0, 1, 4, 2, 5, 3)).contiguous()
    # Merge offsets into the spatial axes.
    return patches.view(batch, out_channels, height * 2, width * 2)
def int_shape(x):
    """Return the size of ``x`` as a list of plain ints."""
    return [int(dim) for dim in x.size()]
class Flatten(Base):
    """Collapse every axis after the first into a single axis."""

    def forward(self, x):
        n_rows = x.size(0)
        return x.view(n_rows, -1)
class Reshape(Base):
    """View the input as ``(x.size(0), *shape)``."""

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target = (x.size(0), *self.shape)
        return x.view(*target)
class Reverse(Base):
    """Reverse the order of the channels (axis 1) of a 4-D tensor.

    The ``reverse`` argument is accepted but not used: flipping is its own
    inverse.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z, reverse=False):
        last = z.size(1) - 1
        # Indices last, last - 1, ..., 0.
        flip_idx = torch.arange(last, -1, -1).long()
        return z[:, flip_idx, :, :]
class Permute(Base):
    """Fixed random permutation of the channel axis.

    The permutation is drawn with ``np.random.shuffle`` at construction;
    ``reverse=True`` applies its inverse.
    """

    def __init__(self, n_channels):
        super().__init__()

        forward_perm = np.arange(n_channels, dtype='int')
        np.random.shuffle(forward_perm)

        # inverse_perm[forward_perm[i]] == i
        inverse_perm = np.zeros(n_channels, dtype='int')
        inverse_perm[forward_perm] = np.arange(n_channels, dtype='int')

        # NOTE(review): these are plain attributes, not registered buffers,
        # so they are presumably not part of state_dict() -- confirm how
        # checkpoints are saved before relying on a reloaded permutation.
        self.permutation = torch.from_numpy(forward_perm)
        self.permutation_inv = torch.from_numpy(inverse_perm)

    def forward(self, z, ldj, reverse=False):
        order = self.permutation_inv if reverse else self.permutation
        z = z[:, order, :, :]
        return z, ldj

    def InversePermute(self):
        """Return a new Permute whose forward pass undoes this one."""
        inverse = Permute(len(self.permutation))
        inverse.permutation = self.permutation_inv
        inverse.permutation_inv = self.permutation
        return inverse
class Squeeze(Base):
    """Flow layer wrapping space_to_depth (forward) / depth_to_space (reverse)."""

    def __init__(self):
        super().__init__()

    def forward(self, z, ldj, reverse=False):
        transform = depth_to_space if reverse else space_to_depth
        return transform(z), ldj
class GenerativeFlow(Base):
    """Multi-level stack of Squeeze, Permute, Coupling and SplitPrior layers.

    ``z_size`` holds the (channels, height, width) of the final latent.
    """

    def __init__(self, n_channels, height, width, args):
        super().__init__()

        layers = [Squeeze()]
        n_channels *= 4
        height //= 2
        width //= 2

        final_level = args.n_levels - 1
        for level in range(args.n_levels):
            for _ in range(args.n_flows):
                layers.append(Permute(n_channels))
                layers.append(Coupling(n_channels, height, width, args))

            if level == final_level:
                continue

            if args.splitprior_type != 'none':
                # Standard splitprior: factor out half of the channels.
                factor_out = n_channels // 2
                layers.append(
                    SplitPrior(n_channels, factor_out, height, width, args))
                n_channels = n_channels - factor_out

            layers.append(Squeeze())
            n_channels *= 4
            height //= 2
            width //= 2

        self.layers = torch.nn.ModuleList(layers)
        self.z_size = (n_channels, height, width)

    def forward(self, z, ldj, pys=(), ys=(), reverse=False):
        """Run the flow; in reverse, consume ``ys`` from the end (or sample)."""
        if reverse:
            for layer in reversed(list(self.layers)):
                if not isinstance(layer, SplitPrior):
                    z, ldj = layer(z, ldj, reverse=True)
                elif len(ys) > 0:
                    z, ldj = layer.inverse(z, ldj, y=ys[-1])
                    # Pop last element
                    ys = ys[:-1]
                else:
                    z, ldj = layer.inverse(z, ldj, y=None)
        else:
            for layer in self.layers:
                if isinstance(layer, SplitPrior):
                    py, y, z, ldj = layer(z, ldj)
                    pys += (py,)
                    ys += (y,)
                else:
                    z, ldj = layer(z, ldj)

        return z, ldj, pys, ys

    def decode(self, z, ldj, state, decode_fn):
        """Invert the flow, letting each SplitPrior decode its part via ``decode_fn``."""
        for layer in reversed(list(self.layers)):
            if isinstance(layer, SplitPrior):
                z, ldj, state = layer.decode(z, ldj, state, decode_fn)
            else:
                z, ldj = layer(z, ldj, reverse=True)

        return z, ldj

View file

@ -0,0 +1,154 @@
"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.utils import Base
UNIT_TESTING = False
class Conv2dReLU(Base):
    """2-D convolution followed by a ReLU."""

    def __init__(
            self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0,
            bias=True):
        super().__init__()

        # NOTE(review): `stride` and `bias` are accepted but not forwarded to
        # nn.Conv2d, so the convolution always uses the Conv2d defaults for
        # both -- confirm whether this is intended before changing it.
        self.nn = nn.Conv2d(n_inputs, n_outputs, kernel_size, padding=padding)

    def forward(self, x):
        return F.relu(self.nn(x))
class ResidualBlock(Base):
    """Two convolutions with a skip connection and a final ReLU."""

    def __init__(self, n_channels, kernel, Conv2dAct):
        super().__init__()

        first = Conv2dAct(n_channels, n_channels, kernel, padding=1)
        second = torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1)
        self.nn = torch.nn.Sequential(first, second)

    def forward(self, x):
        residual = self.nn(x)
        return F.relu(residual + x)
class DenseLayer(Base):
    """Dense layer: concatenates ``growth`` new feature channels to the input."""

    def __init__(self, args, n_inputs, growth, Conv2dAct):
        super().__init__()

        conv1x1 = Conv2dAct(
            n_inputs, n_inputs, kernel_size=1, stride=1,
            padding=0, bias=True)
        conv3x3 = Conv2dAct(
            n_inputs, growth, kernel_size=3, stride=1,
            padding=1, bias=True)

        self.nn = torch.nn.Sequential(conv1x1, conv3x3)

    def forward(self, x):
        new_features = self.nn(x)
        return torch.cat([x, new_features], dim=1)
class DenseBlock(Base):
    """Chain of ``args.densenet_depth`` DenseLayers growing n_inputs to n_outputs.

    The total growth is divided as evenly as integer division allows over
    the layers. ``kernel`` is accepted but not used here.
    """

    def __init__(
            self, args, n_inputs, n_outputs, kernel, Conv2dAct):
        super().__init__()
        depth = args.densenet_depth

        remaining = n_outputs - n_inputs
        width = n_inputs
        layers = []
        for d in range(depth):
            # Share what is left equally among the layers still to come.
            growth = remaining // (depth - d)
            layers.append(DenseLayer(args, width, growth, Conv2dAct))
            width += growth
            remaining -= growth

        self.nn = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.nn(x)
class Identity(Base):
    """Module that returns its input unchanged."""

    def __init__(self):
        # `super` must be called to obtain the proxy object; the previous
        # `super.__init__()` raised TypeError on construction.
        super().__init__()

    def forward(self, x):
        return x
class NN(Base):
    """Convolutional network used inside coupling layers and split priors.

    Arguments
    ---------
    args:
        Settings object; ``n_channels`` is always read, ``network1x1`` for
        the 'shallow' type and ``densenet_depth`` (via DenseBlock) for the
        'densenet' type.
    c_in, c_out:
        Number of input and output channels.
    height, width:
        Accepted for interface compatibility; not used by this class.
    nn_type:
        One of 'shallow', 'resnet' or 'densenet'.
    kernel:
        Kernel size of the non-1x1 convolutions.

    Raises
    ------
    ValueError
        If ``nn_type`` is unknown, or if ``nn_type == 'shallow'`` and
        ``args.network1x1`` is not 'standard'.
    """

    def __init__(
            self, args, c_in, c_out, height, width, nn_type, kernel=3):
        super().__init__()

        Conv2dAct = Conv2dReLU
        n_channels = args.n_channels

        if nn_type == 'shallow':
            # Only the 'standard' 1x1 layer exists; fail with a clear message
            # instead of hitting an unbound local further down.
            if args.network1x1 != 'standard':
                raise ValueError(
                    "Unsupported network1x1 {!r}; expected 'standard'".format(
                        args.network1x1))

            conv1x1 = Conv2dAct(
                n_channels, n_channels, kernel_size=1, stride=1,
                padding=0, bias=False)

            layers = [
                Conv2dAct(c_in, n_channels, kernel, padding=1),
                conv1x1]

            layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)]

        elif nn_type == 'resnet':
            layers = [
                Conv2dAct(c_in, n_channels, kernel, padding=1),
                ResidualBlock(n_channels, kernel, Conv2dAct),
                ResidualBlock(n_channels, kernel, Conv2dAct)]

            layers += [
                torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)
            ]

        elif nn_type == 'densenet':
            layers = [
                DenseBlock(
                    args=args,
                    n_inputs=c_in,
                    n_outputs=n_channels + c_in,
                    kernel=kernel,
                    Conv2dAct=Conv2dAct)]

            layers += [
                torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1)
            ]
        else:
            raise ValueError(
                "Unknown nn_type {!r}; expected 'shallow', 'resnet' or "
                "'densenet'".format(nn_type))

        self.nn = torch.nn.Sequential(*layers)

        # Set parameters of last conv-layer to zero.
        if not UNIT_TESTING:
            self.nn[-1].weight.data.zero_()
            self.nn[-1].bias.data.zero_()

    def forward(self, x):
        return self.nn(x)

View file

@ -0,0 +1,164 @@
"""
Collection of flow strategies
"""
from __future__ import print_function
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from utils.distributions import sample_discretized_logistic, \
sample_mixture_discretized_logistic, sample_normal, sample_logistic, \
sample_discretized_normal, sample_mixture_normal
from models.utils import Base
from .networks import NN
def sample_prior(px, variable_type, distribution_type, inverse_bin_width):
    """Draw a sample from the distribution described by ``px``.

    The sampler is selected from ``variable_type``, ``distribution_type``
    and, where mixtures are supported, the number of entries in ``px``
    (2 for a single component, 3 for a mixture).

    Raises ValueError when no sampler matches the combination.
    """
    is_logistic = distribution_type == 'logistic'
    is_normal = distribution_type == 'normal'

    if variable_type == 'discrete':
        if is_logistic:
            if len(px) == 2:
                return sample_discretized_logistic(
                    *px, inverse_bin_width=inverse_bin_width)
            if len(px) == 3:
                return sample_mixture_discretized_logistic(
                    *px, inverse_bin_width=inverse_bin_width)
        elif is_normal:
            return sample_discretized_normal(
                *px, inverse_bin_width=inverse_bin_width)

    elif variable_type == 'continuous':
        if is_logistic or distribution_type == 'steplogistic':
            return sample_logistic(*px)
        if is_normal:
            if len(px) == 2:
                return sample_normal(*px)
            if len(px) == 3:
                return sample_mixture_normal(*px)

    raise ValueError
class Prior(Base):
    """Learnable prior over the final latent.

    Holds ``mu`` and ``logs`` parameters of size (c, h, w), or with an extra
    trailing mixture axis plus ``pi_logit`` when ``args.n_mixtures > 1``.
    """

    def __init__(self, size, args):
        super().__init__()
        c, h, w = size

        self.inverse_bin_width = 2**args.n_bits
        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type
        self.n_mixtures = args.n_mixtures

        k = self.n_mixtures
        if k == 1:
            self.mu = Parameter(torch.Tensor(c, h, w))
            self.logs = Parameter(torch.Tensor(c, h, w))
        elif k > 1:
            self.mu = Parameter(torch.Tensor(c, h, w, k))
            self.logs = Parameter(torch.Tensor(c, h, w, k))
            self.pi_logit = Parameter(torch.Tensor(c, h, w, k))

        self.reset_parameters()

    def reset_parameters(self):
        """Zero all parameters; spread mixture means symmetrically around 0."""
        self.mu.data.zero_()

        if self.n_mixtures > 1:
            self.pi_logit.data.zero_()
            for component in range(self.n_mixtures):
                self.mu.data[..., component] += \
                    component - (self.n_mixtures - 1) / 2.

        self.logs.data.zero_()

    def get_pz(self, n):
        """Return the prior parameters repeated ``n`` times along a new batch axis."""
        if self.n_mixtures == 1:
            tile = (n, 1, 1, 1)
            mu = self.mu.repeat(*tile)
            logs = self.logs.repeat(*tile)  # scaling scale
            return mu, logs

        elif self.n_mixtures > 1:
            tile = (n, 1, 1, 1, 1)
            pi = F.softmax(self.pi_logit, dim=-1)
            mu = self.mu.repeat(*tile)
            logs = self.logs.repeat(*tile)
            pi = pi.repeat(*tile)
            return mu, logs, pi

    def forward(self, z, ldj):
        return self.get_pz(z.size(0)), z, ldj

    def sample(self, n):
        pz = self.get_pz(n)
        return sample_prior(
            pz, self.variable_type, self.distribution_type,
            self.inverse_bin_width)

    def decode(self, states, decode_fn):
        """Decode the latent from ``states`` with ``decode_fn`` under this prior."""
        pz = self.get_pz(n=len(states))
        states, z = decode_fn(states, pz)
        return states, z
class SplitPrior(Base):
    """Factor out the last ``factor_out`` channels and model them conditionally.

    A network maps the kept channels to the distribution parameters of the
    factored-out channels.
    """

    def __init__(self, c_in, factor_out, height, width, args):
        super().__init__()

        n_kept = c_in - factor_out
        self.split_idx = n_kept
        self.inverse_bin_width = 2**args.n_bits
        self.variable_type = args.variable_type
        self.distribution_type = args.distribution_type
        self.input_channel = c_in

        # Two outputs (mu, logs) per factored-out channel.
        self.nn = NN(
            args=args,
            c_in=n_kept,
            c_out=factor_out * 2,
            height=height,
            width=width,
            nn_type=args.splitprior_type)

    def get_py(self, z):
        """Return [mu, logs]: even and odd output channels of the network."""
        h = self.nn(z)
        return [h[:, ::2, :, :], h[:, 1::2, :, :]]

    def split(self, z):
        kept = z[:, :self.split_idx, :, :]
        factored = z[:, self.split_idx:, :, :]
        return kept, factored

    def combine(self, z, y):
        return torch.cat([z, y], dim=1)

    def forward(self, z, ldj):
        z, y = self.split(z)
        return self.get_py(z), y, z, ldj

    def inverse(self, z, ldj, y):
        # Sample if y is not given.
        if y is None:
            y = sample_prior(
                self.get_py(z), self.variable_type, self.distribution_type,
                self.inverse_bin_width)

        return self.combine(z, y), ldj

    def decode(self, z, ldj, states, decode_fn):
        """Decode the factored-out channels from ``states`` and re-attach them."""
        py = self.get_py(z)
        states, y = decode_fn(states, py)
        return self.combine(z, y), ldj, states

View file

@ -0,0 +1,36 @@
import torch
class Base(torch.nn.Module):
    """
    Base class for modules, with helpers that push a setting (temperature,
    hard rounding) down to every submodule that defines it.
    """

    def __init__(self):
        super().__init__()

    def _set_child_attribute(self, attr, value):
        r"""Assign ``value`` to ``attr`` on every module that already has it.

        This module is checked first, then everything yielded by
        ``self.modules()``. Modules without the attribute are left alone.

        Returns:
            Module: self
        """
        for module in (self, *self.modules()):
            if hasattr(module, attr):
                setattr(module, attr, value)

        return self

    def set_temperature(self, value):
        self._set_child_attribute("temperature", value)

    def enable_hard_round(self, mode=True):
        self._set_child_attribute("hard_round", mode)

    def disable_hard_round(self, mode=True):
        # Disabling is enabling with the flag negated.
        self.enable_hard_round(not mode)

View file

@ -0,0 +1,148 @@
from __future__ import print_function
import numpy as np
import torch
from utils.distributions import log_discretized_logistic, \
log_mixture_discretized_logistic, log_normal, log_discretized_normal, \
log_logistic, log_mixture_normal
from models.backround import _round_straightthrough
def compute_log_ps(pxs, xs, args):
    """Return per-sample log-likelihoods for each (parameters, value) pair.

    For every pair the log-density selected by ``args.variable_type`` and
    ``args.distribution_type`` is evaluated and summed over axes 1, 2, 3.
    """
    # Add likelihoods of intermediate representations.
    inverse_bin_width = 2.**args.n_bits

    totals = []
    for params, value in zip(pxs, xs):

        if args.variable_type == 'discrete':
            if args.distribution_type == 'logistic':
                log_px = log_discretized_logistic(
                    value, *params, inverse_bin_width=inverse_bin_width)
            elif args.distribution_type == 'normal':
                log_px = log_discretized_normal(
                    value, *params, inverse_bin_width=inverse_bin_width)

        elif args.variable_type == 'continuous':
            if args.distribution_type == 'logistic':
                log_px = log_logistic(value, *params)
            elif args.distribution_type == 'normal':
                log_px = log_normal(value, *params)
            elif args.distribution_type == 'steplogistic':
                # Snap to the bin grid, keeping a straight-through gradient.
                value = _round_straightthrough(
                    value * inverse_bin_width) / inverse_bin_width
                log_px = log_discretized_logistic(
                    value, *params, inverse_bin_width=inverse_bin_width)

        totals.append(torch.sum(log_px, dim=[1, 2, 3]))

    return totals
def compute_log_pz(pz, z, args):
    """Return the per-sample log-likelihood of the final latent ``z``.

    The log-density is chosen from ``args.variable_type``,
    ``args.distribution_type`` and ``args.n_mixtures`` and summed over
    axes 1, 2, 3.

    :param pz: distribution parameters, unpacked into the log-density call
    :param z: latent tensor
    :param args: settings providing n_bits, variable_type,
        distribution_type and n_mixtures
    :return: tensor of summed log-likelihoods, one entry per sample
    """
    inverse_bin_width = 2.**args.n_bits

    if args.variable_type == 'discrete':
        if args.distribution_type == 'logistic':
            if args.n_mixtures == 1:
                log_pz = log_discretized_logistic(
                    z, pz[0], pz[1], inverse_bin_width=inverse_bin_width)
            else:
                log_pz = log_mixture_discretized_logistic(
                    z, pz[0], pz[1], pz[2],
                    inverse_bin_width=inverse_bin_width)
        elif args.distribution_type == 'normal':
            log_pz = log_discretized_normal(
                z, *pz, inverse_bin_width=inverse_bin_width)

    elif args.variable_type == 'continuous':
        if args.distribution_type == 'logistic':
            log_pz = log_logistic(z, *pz)
        elif args.distribution_type == 'normal':
            if args.n_mixtures == 1:
                log_pz = log_normal(z, *pz)
            else:
                log_pz = log_mixture_normal(z, *pz)
        elif args.distribution_type == 'steplogistic':
            # Use the configured bin width (previously a fixed 256) and pass
            # it on: log_discretized_logistic requires inverse_bin_width.
            # This mirrors the steplogistic branch of compute_log_ps.
            z = _round_straightthrough(z * inverse_bin_width) / inverse_bin_width
            log_pz = log_discretized_logistic(
                z, *pz, inverse_bin_width=inverse_bin_width)

    log_pz = torch.sum(
        log_pz,
        dim=[1, 2, 3])

    return log_pz
def compute_loss_function(pz, z, pys, ys, ldj, args):
    """
    Computes the batch-averaged loss and bits-per-dimension statistics.

    :param pz: parameters of the prior over the final latent
    :param z: final latent
    :param pys: parameters of the priors over the factored-out variables
    :param ys: factored-out variables
    :param ldj: log det jacobian
    :param args: global parameter settings
    :return: mean loss (tensor), mean bpd (float), list with the mean bpd
        of each prior
    """
    per_sample_loss, per_sample_bpd, per_prior_bpds = \
        compute_loss_array(pz, z, pys, ys, ldj, args)

    mean_loss = torch.mean(per_sample_loss)
    mean_bpd = torch.mean(per_sample_bpd).item()
    mean_bpd_per_prior = [torch.mean(entry) for entry in per_prior_bpds]

    return mean_loss, mean_bpd, mean_bpd_per_prior
def convert_bpd(log_p, input_size):
    """Convert a log-likelihood in nats to bits per dimension."""
    nats_per_bit_total = np.prod(input_size) * np.log(2.)
    return -log_p / nats_per_bit_total
def compute_loss_array(pz, z, pys, ys, ldj, args):
    """
    Computes the per-sample negative log-likelihood and bits per dimension.

    :param pz: parameters of the prior over the final latent
    :param z: final latent
    :param pys: parameters of the priors over the factored-out variables
    :param ys: factored-out variables
    :param ldj: log det jacobian
    :param args: global parameter settings
    :return: loss, bpd, and a list with the bpd contributed by each prior
        (final latent first, then one entry per element of ys)
    """
    # Likelihood of final representation.
    log_pz = compute_log_pz(pz, z, args)
    bpd_per_prior = [convert_bpd(log_pz.detach(), args.input_size)]

    log_p = log_pz

    # Add likelihoods of intermediate representations.
    if ys:
        for log_py in compute_log_ps(pys, ys, args):
            log_p += log_py
            bpd_per_prior.append(
                convert_bpd(log_py.detach(), args.input_size))

    log_p += ldj

    loss = -log_p
    bpd = convert_bpd(log_p.detach(), args.input_size)

    return loss, bpd, bpd_per_prior
def calculate_loss(pz, z, pys, ys, ldj, loss_aux, args):
    """Compute the loss via compute_loss_function.

    ``loss_aux`` is kept in the signature for existing callers but is not
    forwarded: compute_loss_function takes exactly (pz, z, pys, ys, ldj,
    args), and passing a seventh positional argument raised TypeError.
    """
    return compute_loss_function(pz, z, pys, ys, ldj, args)

View file

@ -0,0 +1,174 @@
from __future__ import print_function
import torch
from optimization.loss import calculate_loss
from utils.visual_evaluation import plot_reconstructions
import numpy as np
def train(epoch, train_loader, model, opt, args):
    """Run one epoch of training and return per-batch loss and bpd arrays.

    For each batch the data is reshaped to ``args.input_size``, moved to
    ``args.device`` and fed to ``model``, which is expected to return
    (loss, bpd, bpd_per_prior, pz, z, pys, py, ldj). Progress is printed
    every ``args.log_interval`` batches.

    Returns a pair of numpy arrays with one entry per batch: the mean loss
    and the mean bits per dimension.
    """
    model.train()
    train_loss = np.zeros(len(train_loader))
    train_bpd = np.zeros(len(train_loader))
    num_data = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, *args.input_size)
        data = data.to(args.device)
        opt.zero_grad()
        loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data)
        # Reduce the per-sample outputs to batch means.
        loss = torch.mean(loss)
        bpd = torch.mean(bpd)
        bpd_per_prior = [torch.mean(i) for i in bpd_per_prior]
        loss.backward()
        # From here on `loss` is a plain Python number used for logging.
        loss = loss.item()
        train_loss[batch_idx] = loss
        train_bpd[batch_idx] = bpd
        # Log-det-Jacobian expressed in bits per dimension, for logging only.
        ldj = torch.mean(ldj).item() / np.prod(args.input_size) / np.log(2)
        opt.step()
        num_data += len(data)
        if batch_idx % args.log_interval == 0:
            perc = 100. * batch_idx / len(train_loader)
            tmp = 'Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] \tLoss: {:11.6f}\tbpd: {:8.6f}\tbits ldj: {:8.6f}'
            print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, loss, bpd, ldj))
            print('z min: {:8.3f}, max: {:8.3f}'.format(torch.min(z).item() * 256, torch.max(z).item() * 256))
            # Entry 0 is reported as the z prior, the remaining ones as y0, y1, ...
            print('z bpd: {:.3f}'.format(bpd_per_prior[0]))
            for i in range(1, len(bpd_per_prior)):
                print('y{} bpd: {:.3f}'.format(i-1, bpd_per_prior[i]))
            print('pz mu', np.mean(pz[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
            print('pz logs ', np.mean(pz[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
            # A third entry is printed as the mixture weights.
            if len(pz) == 3:
                print('pz pi ', np.mean(pz[2].data.cpu().numpy(), axis=(0, 1, 2, 3)))
            # NOTE(review): this loop rebinds `py`, shadowing the value
            # unpacked from the model output above.
            for i, py in enumerate(pys):
                print('py{} mu '.format(i), np.mean(py[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
                print('py{} logs '.format(i), np.mean(py[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
            # NOTE(review): plot_images is imported but not used in this
            # function; only the 'training/' snapshot directory is created.
            from utils.visual_evaluation import plot_images
            import os
            if not os.path.exists(args.snap_dir + 'training/'):
                os.makedirs(args.snap_dir + 'training/')
    print('====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'.format(
        epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader)))
    return train_loss, train_bpd
def evaluate(train_loader, val_loader, model, model_sample, args, testing=False, file=None, epoch=0):
    """Evaluate bits per dimension on the train and validation loaders.

    Appends a one-line summary to ``file`` and prints the results. Returns
    ``(loss, bpd)`` when ``testing`` is False and
    ``(log_likelihood, nll_bpd)`` otherwise.
    """
    model.eval()
    loss_type = 'bpd'

    def analyse(data_loader, plot=False):
        # Mean bpd over all batches of `data_loader`; optionally plots
        # samples drawn from `model_sample`.
        bpds = []
        batch_idx = 0
        with torch.no_grad():
            for data, _ in data_loader:
                batch_idx += 1
                if args.cuda:
                    data = data.cuda()
                data = data.view(-1, *args.input_size)
                loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj = \
                    model(data)
                loss = torch.mean(loss).item()
                batch_bpd = torch.mean(batch_bpd).item()
                bpds.append(batch_bpd)
        bpd = np.mean(bpds)
        with torch.no_grad():
            if not testing and plot:
                x_sample = model_sample.sample(n=100)
                # NOTE(review): the bare except hides every plotting error,
                # including unrelated ones.
                try:
                    plot_reconstructions(
                        x_sample, bpd, loss_type, epoch, args)
                except:
                    print('Not plotting')
        return bpd

    bpd_train = analyse(train_loader)
    bpd_val = analyse(val_loader, plot=True)
    # NOTE(review): `file` defaults to None but is opened unconditionally.
    with open(file, 'a') as ff:
        msg = 'epoch {}\ttrain bpd {:.3f}\tval bpd {:.3f}\t'.format(
            epoch,
            bpd_train,
            bpd_val)
        print(msg, file=ff)
    # Convert the validation bpd back to a loss in nats.
    loss = bpd_val * np.prod(args.input_size) * np.log(2.)
    bpd = bpd_val
    # NOTE(review): resetting `file` here means the `else` branch of the
    # `if file is None` check below can never run.
    file = None
    # Compute log-likelihood
    with torch.no_grad():
        if testing:
            test_data = val_loader.dataset.data_tensor
            if args.cuda:
                test_data = test_data.cuda()
            print('Computing log-likelihood on test set')
            model.eval()
            # NOTE(review): analyse() unpacks `data, _` from each item it
            # iterates; here it is handed `test_data` directly -- confirm
            # this path works.
            log_likelihood = analyse(test_data)
        else:
            log_likelihood = None
            nll_bpd = None
        if file is None:
            if testing:
                print('====> Test set loss: {:.4f}'.format(loss))
                print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood))
                print('====> Test set bpd (elbo): {:.4f}'.format(bpd))
                print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood/
                                                                          (np.prod(args.input_size) * np.log(2.))))
            else:
                print('====> Validation set loss: {:.4f}'.format(loss))
                print('====> Validation set bpd: {:.4f}'.format(bpd))
        else:
            with open(file, 'a') as ff:
                if testing:
                    print('====> Test set loss: {:.4f}'.format(loss), file=ff)
                    print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood), file=ff)
                    print('====> Test set bpd: {:.4f}'.format(bpd), file=ff)
                    print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood /
                                                                              (np.prod(args.input_size) * np.log(2.))),
                          file=ff)
                else:
                    print('====> Validation set loss: {:.4f}'.format(loss), file=ff)
                    print('====> Validation set bpd: {:.4f}'.format(loss / (np.prod(args.input_size) * np.log(2.))),
                          file=ff)
    if not testing:
        return loss, bpd
    else:
        # NOTE(review): `nll_bpd` is only assigned in the non-testing branch
        # above, so it is unbound when this line is reached.
        return log_likelihood, nll_bpd

View file

View file

@ -0,0 +1,209 @@
from __future__ import print_function
import torch
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import math
MIN_EPSILON = 1e-5
MAX_EPSILON = 1.-1e-5
PI = math.pi
def log_min_exp(a, b, epsilon=1e-8):
    """
    Computes log(exp(a) - exp(b)) in a (more) numerically stable fashion.

    Uses the identity
        log(exp(a) - exp(b)) = a + log(1 - exp(b - a))
    with ``epsilon`` added inside the logarithm. Assumes b < a.
    """
    ratio = torch.exp(b - a)
    return a + torch.log(1 - ratio + epsilon)
def log_normal(x, mean, logvar):
    """Element-wise log-density of a normal with the given mean and log-variance."""
    deviation = x - mean
    logp = -0.5 * logvar
    logp += -0.5 * np.log(2 * PI)
    logp += -0.5 * deviation * deviation / torch.exp(logvar)
    return logp
def log_mixture_normal(x, mean, logvar, pi):
    """Log-density of a mixture of normals; components live on the last axis."""
    # Trailing axis so x broadcasts against the per-component parameters.
    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
    component_density = torch.exp(log_normal(x, mean, logvar))
    mixture_density = torch.sum(pi * component_density, dim=-1)
    return torch.log(mixture_density + 1e-8)
def sample_normal(mean, logvar):
    """Sample from a normal via the reparameterisation mean + std * noise."""
    noise = torch.randn_like(mean)
    std = torch.exp(0.5 * logvar)
    return std * noise + mean
def sample_mixture_normal(mean, logvar, pi):
    """Sample from a mixture of normals with components on the last axis."""
    b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
    n_sites = b * c * h * w

    # Draw one component index per site.
    weights = pi.view(n_sites, n_mixtures)
    sampled_pi = torch.multinomial(weights, num_samples=1).view(-1)

    def select(params):
        # Pick the parameter of the sampled component at every site.
        flat = params.view(n_sites, n_mixtures)
        return flat[torch.arange(n_sites), sampled_pi].view(b, c, h, w)

    mean = select(mean)
    logvar = select(logvar)

    return sample_normal(mean, logvar)
def log_logistic(x, mean, logscale):
    """
    Log-density of the logistic distribution:
    pdf = sigma(u) * [1 - sigma(u)] / scale, with u = (x - mean) / scale.
    """
    u = (x - mean) / torch.exp(logscale)
    # log(1 - sigma(u)) == logsigmoid(-u)
    return F.logsigmoid(u) + F.logsigmoid(-u) - logscale
def sample_logistic(mean, logscale):
    """Sample from a logistic distribution by inverting its CDF on uniform noise."""
    u = torch.rand_like(mean)
    logit = torch.log(u / (1 - u))
    return torch.exp(logscale) * logit + mean
def log_discretized_logistic(x, mean, logscale, inverse_bin_width):
    """Log-probability of the bin around ``x`` under a logistic distribution.

    Computed as the log of the CDF difference between the upper and lower
    bin edges, x +/- 0.5 / inverse_bin_width.
    """
    scale = torch.exp(logscale)
    log_cdf_upper = F.logsigmoid((x + 0.5 / inverse_bin_width - mean) / scale)
    log_cdf_lower = F.logsigmoid((x - 0.5 / inverse_bin_width - mean) / scale)
    return log_min_exp(log_cdf_upper, log_cdf_lower)
def discretized_logistic_cdf(x, mean, logscale, inverse_bin_width):
    """Logistic CDF evaluated at the upper edge of the bin around ``x``."""
    upper_edge = x + 0.5 / inverse_bin_width
    return torch.sigmoid((upper_edge - mean) / torch.exp(logscale))
def sample_discretized_logistic(mean, logscale, inverse_bin_width):
    """Sample a logistic variable and round it to the nearest bin centre."""
    continuous = sample_logistic(mean, logscale)
    return torch.round(continuous * inverse_bin_width) / inverse_bin_width
def normal_cdf(value, loc, std):
    """CDF of a normal distribution with mean ``loc`` and deviation ``std``."""
    standardized = (value - loc) * std.reciprocal()
    return 0.5 * (1 + torch.erf(standardized / math.sqrt(2)))
def log_discretized_normal(x, mean, logvar, inverse_bin_width):
    """Log-probability of the bin around ``x`` under a normal distribution."""
    std = torch.exp(0.5 * logvar)
    cdf_upper = normal_cdf(x + 0.5 / inverse_bin_width, mean, std)
    cdf_lower = normal_cdf(x - 0.5 / inverse_bin_width, mean, std)
    # Small constant keeps the logarithm finite for empty bins.
    return torch.log(cdf_upper - cdf_lower + 1e-7)
def log_mixture_discretized_normal(x, mean, logvar, pi, inverse_bin_width):
    """Log-probability of the bin around ``x`` under a mixture of normals.

    Mixture components live on the last axis of ``mean``, ``logvar``, ``pi``.
    """
    std = torch.exp(0.5 * logvar)
    # Trailing axis so x broadcasts against the per-component parameters.
    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
    cdf_upper = normal_cdf(x + 0.5 / inverse_bin_width, mean, std)
    cdf_lower = normal_cdf(x - 0.5 / inverse_bin_width, mean, std)
    bin_mass = torch.sum((cdf_upper - cdf_lower) * pi, dim=-1)
    return torch.log(bin_mass + 1e-8)
def sample_discretized_normal(mean, logvar, inverse_bin_width):
    """Sample a normal variable and round it to the nearest bin centre."""
    noise = torch.randn_like(mean)
    continuous = torch.exp(0.5 * logvar) * noise + mean
    return torch.round(continuous * inverse_bin_width) / inverse_bin_width
def log_mixture_discretized_logistic(x, mean, logscale, pi, inverse_bin_width):
    """Log-probability of the bin around ``x`` under a mixture of logistics.

    Mixture components live on the last axis of ``mean``, ``logscale``, ``pi``.
    """
    scale = torch.exp(logscale)
    # Trailing axis so x broadcasts against the per-component parameters.
    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
    cdf_upper = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)
    cdf_lower = torch.sigmoid((x - 0.5 / inverse_bin_width - mean) / scale)
    bin_mass = torch.sum((cdf_upper - cdf_lower) * pi, dim=-1)
    return torch.log(bin_mass + 1e-8)
def mixture_discretized_logistic_cdf(x, mean, logscale, pi, inverse_bin_width):
    """Mixture-of-logistics CDF at the upper edge of the bin around ``x``."""
    scale = torch.exp(logscale)
    upper_edge = x[..., None] + 0.5 / inverse_bin_width
    component_cdfs = torch.sigmoid((upper_edge - mean) / scale)
    return torch.sum(component_cdfs * pi, dim=-1)
def sample_mixture_discretized_logistic(mean, logs, pi, inverse_bin_width):
    """Sample from a mixture of logistics and round to the nearest bin centre.

    Mixture components live on the last axis of ``mean``, ``logs``, ``pi``.
    """
    # Sample mixtures
    b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
    n_sites = b * c * h * w
    weights = pi.view(n_sites, n_mixtures)
    sampled_pi = torch.multinomial(weights, num_samples=1).view(-1)

    def select(params):
        # Select mixture params of the sampled component at every site.
        flat = params.view(n_sites, n_mixtures)
        return flat[torch.arange(n_sites), sampled_pi].view(b, c, h, w)

    mean = select(mean)
    logs = select(logs)

    u = torch.rand_like(mean)
    continuous = torch.exp(logs) * torch.log(u / (1 - u)) + mean
    return torch.round(continuous * inverse_bin_width) / inverse_bin_width
def log_multinomial(logits, targets):
    """Per-element log-probability of ``targets`` under softmax(``logits``)."""
    nll = F.cross_entropy(logits, targets, reduction='none')
    return -nll
def sample_multinomial(logits):
    """Sample one category index per site from (b, n_categories, c, h, w) logits."""
    b, n_categories, c, h, w = logits.size()
    # Move the category axis last before normalising.
    probs = F.softmax(logits.permute(0, 2, 3, 4, 1), dim=-1)
    probs = probs.view(b * c * h * w, n_categories)
    draws = torch.multinomial(probs, num_samples=1)
    return draws.view(b, c, h, w)

View file

@ -0,0 +1,264 @@
from __future__ import print_function
import numbers
import torch
import torch.utils.data as data_utils
import pickle
from scipy.io import loadmat
import numpy as np
import os
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from torchvision.transforms import functional as vf
from torch.utils.data import ConcatDataset
from PIL import Image
import os
import os.path
from os.path import join
import sys
import tarfile
class ToTensorNoNorm():
    """Convert an (H, W, C) array-like image to a (C, H, W) tensor.

    Values are not rescaled.
    """

    def __call__(self, X_i):
        # np.asarray avoids a copy when it can and copies otherwise;
        # np.array(..., copy=False) raises on NumPy 2 when a copy is needed.
        return torch.from_numpy(np.asarray(X_i)).permute(2, 0, 1)
class PadToMultiple(object):
    """Pad an image so that its width and height are multiples of ``multiple``.

    Padding is added only on the right and bottom borders.
    """

    def __init__(self, multiple, fill=0, padding_mode='constant'):
        assert isinstance(multiple, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']

        self.multiple = multiple
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.
        Returns:
            PIL Image: Padded image.
        """
        w, h = img.size
        m = self.multiple
        # Round each side up to the next multiple of m.
        nw = (w // m + int((w % m) != 0)) * m
        nh = (h // m + int((h % m) != 0)) * m
        padw = nw - w
        padh = nh - h

        # Padding tuple is (left, top, right, bottom).
        out = vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode)
        return out

    def __repr__(self):
        # Previously read the misspelled attribute `mulitple`, which raised
        # AttributeError whenever the transform was printed.
        return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\
            format(self.multiple, self.fill, self.padding_mode)
class CustomTensorDataset(Dataset):
    """Dataset wrapping an (inputs, targets) pair of tensors.

    Each sample is retrieved by indexing both tensors along the first
    dimension. The optional transform is applied to the input sample, which
    is then converted to a tensor and permuted from (H, W, C) to (C, H, W).

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first
            dimension; ``__getitem__`` expects exactly two.
        transform (callable, optional): applied to each input sample.
    """

    def __init__(self, *tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        X, y = self.tensors
        X_i, y_i, = X[index], y[index]

        if self.transform:
            X_i = self.transform(X_i)

        # np.asarray copies only when it has to; np.array(..., copy=False)
        # raises on NumPy 2 when the transform output needs a copy.
        X_i = torch.from_numpy(np.asarray(X_i))
        X_i = X_i.permute(2, 0, 1)

        return X_i, y_i

    def __len__(self):
        return self.tensors[0].size(0)
def load_cifar10(args, **kwargs):
    """Build CIFAR-10 train/validation/test loaders and update ``args``.

    The last 10000 training images form the validation split. Only the
    training set goes through the augmentation pipeline selected by
    ``args.data_augmentation_level``. Extra keyword arguments are passed to
    every DataLoader.

    Returns (train_loader, val_loader, test_loader, args).
    """
    # set args
    args.input_size = [3, 32, 32]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    from keras.datasets import cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # Channels-last to channels-first.
    x_train = x_train.transpose(0, 3, 1, 2)
    x_test = x_test.transpose(0, 3, 1, 2)

    import math

    steps = [transforms.ToPILImage()]
    if args.data_augmentation_level == 2:
        steps += [
            transforms.RandomHorizontalFlip(),
            transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.CenterCrop(32),
        ]
    elif args.data_augmentation_level == 1:
        steps.append(transforms.RandomHorizontalFlip())
    data_transform = transforms.Compose(steps)

    # Hold out the tail of the training set for validation.
    x_val, y_val = x_train[-10000:], y_train[-10000:]
    x_train, y_train = x_train[:-10000], y_train[:-10000]

    train = CustomTensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train),
        transform=data_transform)
    train_loader = data_utils.DataLoader(
        train, batch_size=args.batch_size, shuffle=True, **kwargs)

    validation = data_utils.TensorDataset(
        torch.from_numpy(x_val), torch.from_numpy(y_val))
    val_loader = data_utils.DataLoader(
        validation, batch_size=args.batch_size, shuffle=False, **kwargs)

    test = data_utils.TensorDataset(
        torch.from_numpy(x_test), torch.from_numpy(y_test))
    test_loader = data_utils.DataLoader(
        test, batch_size=args.batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader, args
def extract_tar(tarpath):
    """Extract a flat ``.tar`` archive into sub-folders of at most 50000 files.

    Files are written to ``<tarpath without .tar>/images<t>/`` with their
    directory components stripped; directory entries in the archive are
    skipped. If the target directory already exists, nothing is extracted.

    :param tarpath: path to the archive; must end in ``.tar``
    :return: the extraction directory (with trailing slash)
    """
    assert tarpath.endswith('.tar')

    startdir = tarpath[:-4] + '/'

    # An existing directory is taken to mean the archive was already extracted.
    if os.path.exists(startdir):
        return startdir

    print('Extracting', tarpath)

    with tarfile.open(name=tarpath) as tar:
        t = 0
        done = False
        while not done:
            path = join(startdir, 'images{}'.format(t))
            os.makedirs(path, exist_ok=True)

            print(path)

            for _ in range(50000):
                member = tar.next()

                # Skip directories. The None check has to be part of the loop
                # condition: an archive that ends with a directory entry used to
                # fall through to `member.name` with member == None.
                while member is not None and member.isdir():
                    member = tar.next()

                if member is None:
                    done = True
                    break

                # Flatten: keep only the file name, drop any leading folders.
                member.name = member.name.split('/')[-1]

                tar.extract(member, path=path)

            t += 1

    return startdir
def load_imagenet(resolution, args, **kwargs):
    """Build train/validation/test loaders for downsampled ImageNet.

    The archives ``../imagenet<res>/train_<res>x<res>.tar`` and
    ``../imagenet<res>/valid_<res>x<res>.tar`` are extracted on first use.
    20000 randomly drawn images from the train archive become the validation
    split; the "valid" archive is used as the test set.

    :param resolution: 32 or 64
    :param args: namespace providing ``batch_size``; ``input_size`` is set on it
    :param kwargs: forwarded to every ``DataLoader``
    :return: (train_loader, val_loader, test_loader, args)
    """
    assert resolution in (32, 64)

    args.input_size = [3, resolution, resolution]

    trainpath = extract_tar(
        '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution))
    valpath = extract_tar(
        '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution))

    data_transform = transforms.Compose([ToTensorNoNorm()])

    print('Starting loading ImageNet')

    imagenet_data = torchvision.datasets.ImageFolder(trainpath, transform=data_transform)
    n_images = len(imagenet_data)

    print('Number of data images', n_images)

    # Random hold-out of 20000 indices; everything else is used for training.
    val_idcs = np.random.choice(n_images, size=20000, replace=False)
    train_idcs = np.setdiff1d(np.arange(n_images), val_idcs)

    subset_cls = torch.utils.data.dataset.Subset
    loader_cls = torch.utils.data.DataLoader

    train_dataset = subset_cls(imagenet_data, train_idcs)
    val_dataset = subset_cls(imagenet_data, val_idcs)

    train_loader = loader_cls(
        train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = loader_cls(
        val_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)

    # The official "valid" archive serves as the test set.
    test_dataset = torchvision.datasets.ImageFolder(valpath, transform=data_transform)

    print('Number of val images:', len(test_dataset))

    test_loader = loader_cls(
        test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader, args
def load_dataset(args, **kwargs):
    """Dispatch to the loader that matches ``args.dataset``.

    Supported names are 'cifar10', 'imagenet32' and 'imagenet64'.

    :param args: namespace with a ``dataset`` attribute (plus whatever the
        selected loader reads)
    :param kwargs: forwarded to the selected loader
    :return: (train_loader, val_loader, test_loader, args)
    :raises Exception: if the dataset name is not recognised
    """
    name = args.dataset

    if name == 'cifar10':
        loaders = load_cifar10(args, **kwargs)
    elif name == 'imagenet32' or name == 'imagenet64':
        resolution = 32 if name == 'imagenet32' else 64
        loaders = load_imagenet(resolution, args, **kwargs)
    else:
        raise Exception('Wrong name of the dataset!')

    train_loader, val_loader, test_loader, args = loaders
    return train_loader, val_loader, test_loader, args
if __name__ == '__main__':
    # Manual smoke test: build the ImageNet32 loaders with a minimal args object.
    class Args():
        def __init__(self):
            self.batch_size = 128

    # Fixed: `load_imagenet32` is not defined anywhere in this module; the
    # loader is `load_imagenet`, which takes the resolution as first argument.
    train_loader, val_loader, test_loader, args = load_imagenet(32, Args())

View file

@ -0,0 +1,56 @@
from __future__ import print_function

import numpy as np

try:
    # Current location; scipy.misc.logsumexp was deprecated and later removed.
    from scipy.special import logsumexp
except ImportError:
    # Fallback for old SciPy releases that only ship it in scipy.misc.
    from scipy.misc import logsumexp

from optimization.loss import calculate_loss_array
def calculate_likelihood(X, model, args, S=5000, MB=500):
    """Estimate the negative log-likelihood of ``X`` by averaging over samples.

    For every data point the model is evaluated on ``S`` copies of it (split
    into chunks of ``MB`` when ``S > MB``). The negated per-copy losses are
    combined with log-sum-exp and normalised by the number of copies.

    :param X: data tensor; reshaped to ``(-1, *args.input_size)``
    :param model: callable returning ``(x_mean, z_mu, z_var, ldj, z0, zk)``
    :param args: namespace providing ``input_size`` and ``input_type``
    :param S: number of copies evaluated per data point
    :param MB: maximum number of copies per forward pass
    :return: ``(nll, bpd)``; ``bpd`` is 0. for 'binary' input
    :raises ValueError: if ``args.input_type`` is neither 'multinomial' nor 'binary'
    """
    n_points = X.size(0)
    X = X.view(-1, *args.input_size)

    # Number of forward passes (R) and copies per pass (S)
    if S > MB:
        R, S = S // MB, MB
    else:
        R = 1

    per_point = []
    for j in range(n_points):
        if j % 100 == 0:
            print('Progress: {:.2f}%'.format(j / (1. * n_points) * 100))

        x_single = X[j].unsqueeze(0)

        chunks = []
        for _ in range(R):
            # Replicate the single data point S times along the batch axis
            x = x_single.expand(S, *x_single.size()[1:]).contiguous()

            x_mean, z_mu, z_var, ldj, z0, zk = model(x)
            losses = calculate_loss_array(x_mean, x, z_mu, z_var, z0, zk, ldj, args)

            chunks.append(-losses.cpu().data.numpy())

        # Flatten all chunks into one column and average in log-space
        stacked = np.asarray(chunks)
        stacked = np.reshape(stacked, (stacked.shape[0] * stacked.shape[1], 1))
        per_point.append(logsumexp(stacked) - np.log(len(stacked)))

    nll = -np.mean(np.array(per_point))

    input_type = args.input_type
    if input_type == 'multinomial':
        bpd = nll/(np.prod(args.input_size) * np.log(2.))
    elif input_type == 'binary':
        bpd = 0.
    else:
        raise ValueError('invalid input type!')

    return nll, bpd

View file

@ -0,0 +1,104 @@
from __future__ import division
from __future__ import print_function
import numpy as np
import matplotlib
# non-interactive backend (no display required)
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def plot_training_curve(train_loss, validation_loss, fname='training_curve.pdf', labels=None):
    """
    Plots train_loss and validation loss as a function of optimization iteration
    :param train_loss: np.array of train_loss (1D or 2D)
    :param validation_loss: np.array of validation loss (1D or 2D)
    :param fname: output file name
    :param labels: if train_loss and validation loss are 2D, then labels indicate which variable is varied
        across training curves.
    :return: None
    :raises ValueError: if the arrays are not 1D/2D, or if the number of training
        steps is neither equal to nor a multiple of the number of validation steps
    """

    plt.close()

    matplotlib.rcParams.update({'font.size': 14})
    matplotlib.rcParams['mathtext.fontset'] = 'stix'
    matplotlib.rcParams['font.family'] = 'STIXGeneral'

    figsize = (6, 4)

    if len(train_loss.shape) == 1:
        # Single training curve
        n_train = train_loss.shape[0]
        n_val = validation_loss.shape[0]

        fig, ax = plt.subplots(nrows=1, ncols=1)

        if n_train == n_val:
            # validation score evaluated every iteration
            x = np.arange(n_train)
            ax.plot(x, train_loss, '-', lw=2., color='black', label='train')
            ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')

        elif n_train % n_val == 0:
            # validation score evaluated every epoch
            x = np.arange(n_train)
            ax.plot(x, train_loss, '-', lw=2., color='black', label='train')

            # place each validation point at the end of its epoch
            x = np.arange(n_val)
            x = (x + 1) * n_train / n_val
            ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')

        else:
            raise ValueError('Length of train_loss and validation_loss must be equal or divisible')

        miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
        maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
        ax.set_ylim([miny, maxy])

    elif len(train_loss.shape) == 2:
        # Multiple training curves: one row per curve, one column per step
        n_curves = train_loss.shape[0]
        n_train = train_loss.shape[1]
        n_val = validation_loss.shape[1]

        cmap = plt.cm.brg
        cNorm = matplotlib.colors.Normalize(vmin=0, vmax=n_curves)
        scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)

        fig, ax = plt.subplots(nrows=1, ncols=1)

        if labels is None:
            labels = ['%d' % i for i in range(n_curves)]

        if n_train == n_val:
            # validation score evaluated every iteration.
            # Fixed: the x axis must span the steps (shape[1]); it used to be
            # built from the number of curves (shape[0]).
            x = np.arange(n_train)
            for i in range(n_curves):
                color_val = scalarMap.to_rgba(i)
                ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i])
                ax.plot(x, validation_loss[i], '--', lw=2., color=color_val)

        elif n_train % n_val == 0:
            # validation score evaluated every epoch
            x_train = np.arange(n_train)
            x_val = (np.arange(n_val) + 1) * n_train / n_val
            for i in range(n_curves):
                color_val = scalarMap.to_rgba(i)
                ax.plot(x_train, train_loss[i], '-', lw=2., color=color_val, label=labels[i])
                ax.plot(x_val, validation_loss[i], '-', lw=2., color=color_val)

        else:
            # Same contract as the 1D case instead of silently plotting nothing
            raise ValueError('Length of train_loss and validation_loss must be equal or divisible')

        miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
        maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
        ax.set_ylim([miny, maxy])

    else:
        raise ValueError('train_loss and validation_loss must be 1D or 2D arrays')

    ax.set_xlabel('iteration')
    ax.set_ylabel('loss')
    # The curves carry labels, so show them; without a legend they are invisible.
    ax.legend()
    plt.title('Training and validation loss')

    fig.set_size_inches(figsize)
    fig.subplots_adjust(hspace=0.1)
    plt.savefig(fname, bbox_inches='tight')

    plt.close()

View file

@ -0,0 +1,37 @@
from __future__ import print_function
import os
import numpy as np
import imageio
def plot_reconstructions(recon_mean, loss, loss_type, epoch, args):
    """Save a mosaic of up to 100 reconstructions under ``<snap_dir>reconstruction/``.

    The file name encodes the epoch and the loss value, e.g. ``3_bpd_4.125``.

    :param recon_mean: tensor of reconstructions; the first 100 are plotted
    :param loss: scalar loss value used in the file name
    :param loss_type: 'bpd' or 'elbo'
    :param epoch: current epoch, used as file name prefix
    :param args: namespace providing ``snap_dir`` (expected to end with a path separator)
    :raises ValueError: if ``loss_type`` is neither 'bpd' nor 'elbo'
    """
    # Validate first so an unknown loss type fails with a clear message instead
    # of an unbound `fname` further down.
    if loss_type == 'bpd':
        fname = str(epoch) + '_bpd_%5.3f' % loss
    elif loss_type == 'elbo':
        fname = str(epoch) + '_elbo_%6.4f' % loss
    else:
        raise ValueError("loss_type must be 'bpd' or 'elbo', got %r" % (loss_type,))

    out_dir = args.snap_dir + 'reconstruction/'
    # Create the directory whenever it is missing, not only at epoch 1, so that
    # runs which start or resume at a later epoch still have somewhere to write.
    os.makedirs(out_dir, exist_ok=True)

    plot_images(args, recon_mean.data.cpu().numpy()[:100], out_dir, fname)
def plot_images(args, x_sample, dir, file_name, size_x=10, size_y=10):
    """Tile a batch of images into a ``size_y`` x ``size_x`` grid and save it as PNG.

    Images fill the grid row by row. If the batch holds fewer than
    ``size_x * size_y`` images the remaining cells stay zero; surplus images
    are ignored.

    :param args: unused; kept for interface compatibility
    :param x_sample: array of shape (batch, channels, height, width)
    :param dir: output directory prefix (concatenated directly with file_name)
    :param file_name: file name without extension
    :param size_x: number of grid columns
    :param size_y: number of grid rows
    """
    batch, channels, height, width = x_sample.shape

    print(x_sample.shape)

    mosaic = np.zeros((height * size_y, width * size_x, channels))

    # Never index past the end of the batch (the last batch may be small).
    n_cells = min(batch, size_x * size_y)
    for idx in range(n_cells):
        j, i = divmod(idx, size_x)

        # Fixed: columns advance by `width`; the old code used `height` for
        # both axes, which is only correct for square images.
        mosaic[j * height:(j + 1) * height, i * width:(i + 1) * width] = \
            x_sample[idx].transpose(1, 2, 0)

    # Remove channel for BW images
    if channels == 1:
        mosaic = mosaic[:, :, 0]

    imageio.imwrite(dir + file_name + '.png', mosaic)