Merge branch 'main' into portability
This commit is contained in:
commit
013afe832f
44 changed files with 5 additions and 2834 deletions
|
|
@ -1,19 +0,0 @@
|
||||||
Copyright (c) 2019 Emiel Hoogeboom, Jorn Peters, Rianne van den Berg, Max Welling
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
@ -1,29 +0,0 @@
|
||||||
# Integer Discrete Flows and Lossless Compression
|
|
||||||
|
|
||||||
This repository contains the code for the experiments presented in [1].
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
### CIFAR10 setup:
|
|
||||||
```
|
|
||||||
python main_experiment.py --n_flows 8 --n_levels 3 --n_channels 512 --coupling_type 'densenet' --densenet_depth 12 —n_mixtures 5 —splitprior_type ‘densenet’
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### ImageNet32 setup:
|
|
||||||
```
|
|
||||||
python main_experiment.py --evaluate_interval_epochs 5 --n_flows 8 --n_levels 3 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet32' --epochs 100 --lr_decay 0.99
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### ImageNet64 setup:
|
|
||||||
```
|
|
||||||
python main_experiment.py --evaluate_interval_epochs 1 --n_flows 8 --n_levels 4 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet64' --epochs 20 --lr_decay 0.99 --batch_size 64
|
|
||||||
```
|
|
||||||
|
|
||||||
# Acknowledgements
|
|
||||||
The Robert Bosch GmbH is acknowledged for financial support.
|
|
||||||
|
|
||||||
# References
|
|
||||||
[1] Hoogeboom, Emiel, Jorn WT Peters, Rianne van den Berg, and Max Welling. "Integer Discrete Flows and Lossless Compression." Conference on Neural Information Processing Systems (2019).
|
|
||||||
|
|
||||||
|
|
@ -1,132 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
from . import rans
|
|
||||||
from utils.distributions import discretized_logistic_cdf, \
|
|
||||||
mixture_discretized_logistic_cdf
|
|
||||||
import torch
|
|
||||||
|
|
||||||
precision = 24
|
|
||||||
n_bins = 4096
|
|
||||||
|
|
||||||
|
|
||||||
def cdf_fn(z, pz, variable_type, distribution_type, inverse_bin_width):
|
|
||||||
if variable_type == 'discrete':
|
|
||||||
if distribution_type == 'logistic':
|
|
||||||
if len(pz) == 2:
|
|
||||||
return discretized_logistic_cdf(
|
|
||||||
z, *pz, inverse_bin_width=inverse_bin_width)
|
|
||||||
elif len(pz) == 3:
|
|
||||||
return mixture_discretized_logistic_cdf(
|
|
||||||
z, *pz, inverse_bin_width=inverse_bin_width)
|
|
||||||
elif distribution_type == 'normal':
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif variable_type == 'continuous':
|
|
||||||
if distribution_type == 'logistic':
|
|
||||||
pass
|
|
||||||
elif distribution_type == 'normal':
|
|
||||||
pass
|
|
||||||
elif distribution_type == 'steplogistic':
|
|
||||||
pass
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
|
|
||||||
def CDF_fn(pz, bin_width, variable_type, distribution_type):
|
|
||||||
mean = pz[0] if len(pz) == 2 else pz[0][..., (pz[0].size(-1) - 1) // 2]
|
|
||||||
MEAN = torch.round(mean / bin_width).long()
|
|
||||||
|
|
||||||
bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
|
|
||||||
bin_locations = bin_locations.float() * bin_width
|
|
||||||
bin_locations = bin_locations.to(device=pz[0].DEVICE)
|
|
||||||
|
|
||||||
pz = [param[:, :, :, :, None] for param in pz]
|
|
||||||
cdf = cdf_fn(
|
|
||||||
bin_locations - bin_width,
|
|
||||||
pz,
|
|
||||||
variable_type,
|
|
||||||
distribution_type,
|
|
||||||
1./bin_width).cpu().numpy()
|
|
||||||
|
|
||||||
# Compute CDFs, reweigh to give all bins at least
|
|
||||||
# 1 / (2^precision) probability.
|
|
||||||
# CDF is equal to floor[cdf * (2^precision - n_bins)] + range(n_bins)
|
|
||||||
CDFs = (cdf * ((1 << precision) - n_bins)).astype('int') \
|
|
||||||
+ np.arange(n_bins)
|
|
||||||
|
|
||||||
return CDFs, MEAN
|
|
||||||
|
|
||||||
|
|
||||||
def encode_sample(
|
|
||||||
z, pz, variable_type, distribution_type, bin_width=1./256, state=None):
|
|
||||||
if state is None:
|
|
||||||
state = rans.x_init
|
|
||||||
else:
|
|
||||||
state = rans.unflatten(state)
|
|
||||||
|
|
||||||
CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)
|
|
||||||
|
|
||||||
# z is transformed to Z to match the indices for the CDFs array
|
|
||||||
Z = torch.round(z / bin_width).long() + n_bins // 2 - MEAN
|
|
||||||
Z = Z.cpu().numpy()
|
|
||||||
|
|
||||||
if not ((np.sum(Z < 0) == 0 and np.sum(Z >= n_bins-1) == 0)):
|
|
||||||
print('Z out of allowed range of values, canceling compression')
|
|
||||||
return None
|
|
||||||
|
|
||||||
Z, CDFs = Z.reshape(-1), CDFs.reshape(-1, n_bins).copy()
|
|
||||||
for symbol, cdf in zip(Z[::-1], CDFs[::-1]):
|
|
||||||
statfun = statfun_encode(cdf)
|
|
||||||
state = rans.append_symbol(statfun, precision)(state, symbol)
|
|
||||||
|
|
||||||
state = rans.flatten(state)
|
|
||||||
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
def decode_sample(
|
|
||||||
state, pz, variable_type, distribution_type, bin_width=1./256):
|
|
||||||
state = rans.unflatten(state)
|
|
||||||
|
|
||||||
device = pz[0].DEVICE
|
|
||||||
size = pz[0].size()[0:4]
|
|
||||||
|
|
||||||
CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)
|
|
||||||
|
|
||||||
CDFs = CDFs.reshape(-1, n_bins)
|
|
||||||
result = np.zeros(len(CDFs), dtype=int)
|
|
||||||
for i, cdf in enumerate(CDFs):
|
|
||||||
statfun = statfun_decode(cdf)
|
|
||||||
state, symbol = rans.pop_symbol(statfun, precision)(state)
|
|
||||||
result[i] = symbol
|
|
||||||
|
|
||||||
Z_flat = torch.from_numpy(result).to(device)
|
|
||||||
Z = Z_flat.view(size) - n_bins // 2 + MEAN
|
|
||||||
|
|
||||||
z = Z.float() * bin_width
|
|
||||||
|
|
||||||
state = rans.flatten(state)
|
|
||||||
|
|
||||||
return state, z
|
|
||||||
|
|
||||||
|
|
||||||
def statfun_encode(CDF):
|
|
||||||
def _statfun_encode(symbol):
|
|
||||||
return CDF[symbol], CDF[symbol + 1] - CDF[symbol]
|
|
||||||
return _statfun_encode
|
|
||||||
|
|
||||||
|
|
||||||
def statfun_decode(CDF):
|
|
||||||
def _statfun_decode(cf):
|
|
||||||
# Search such that CDF[s] <= cf < CDF[s]
|
|
||||||
s = np.searchsorted(CDF, cf, side='right')
|
|
||||||
s = s - 1
|
|
||||||
start, freq = statfun_encode(CDF)(s)
|
|
||||||
return s, (start, freq)
|
|
||||||
return _statfun_decode
|
|
||||||
|
|
||||||
|
|
||||||
def encode(x, symbol):
|
|
||||||
return rans.append_symbol(statfun_encode, precision)(x, symbol)
|
|
||||||
|
|
||||||
|
|
||||||
def decode(x):
|
|
||||||
return rans.pop_symbol(statfun_decode, precision)(x)
|
|
||||||
|
|
@ -1,67 +0,0 @@
|
||||||
"""
|
|
||||||
Closely based on https://github.com/rygorous/ryg_rans/blob/master/rans64.h
|
|
||||||
|
|
||||||
ROUGH GUIDE:
|
|
||||||
We use the pythonic names 'append' and 'pop' for encoding and decoding
|
|
||||||
respectively. The compressed state 'x' is an immutable stack, implemented using
|
|
||||||
a cons list.
|
|
||||||
|
|
||||||
x: the current stack-like state of the encoder/decoder.
|
|
||||||
|
|
||||||
precision: the natural numbers are divided into ranges of size 2^precision.
|
|
||||||
|
|
||||||
start & freq: start indicates the beginning of the range in [0, 2^precision-1]
|
|
||||||
that the current symbol is represented by. freq is the length of the range.
|
|
||||||
freq is chosen such that p(symbol) ~= freq/2^precision.
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
from functools import reduce
|
|
||||||
|
|
||||||
|
|
||||||
rans_l = 1 << 31 # the lower bound of the normalisation interval
|
|
||||||
tail_bits = (1 << 32) - 1
|
|
||||||
|
|
||||||
x_init = (rans_l, ())
|
|
||||||
|
|
||||||
def append(x, start, freq, precision):
|
|
||||||
"""Encodes a symbol with range [start, start + freq). All frequencies are
|
|
||||||
assumed to sum to "1 << precision", and the resulting bits get written to
|
|
||||||
x."""
|
|
||||||
if x[0] >= ((rans_l >> precision) << 32) * freq:
|
|
||||||
x = (x[0] >> 32, (x[0] & tail_bits, x[1]))
|
|
||||||
return ((x[0] // freq) << precision) + (x[0] % freq) + start, x[1]
|
|
||||||
|
|
||||||
def pop(x_, precision):
|
|
||||||
"""Advances in the bit stream by "popping" a single symbol with range start
|
|
||||||
"start" and frequency "freq"."""
|
|
||||||
cf = x_[0] & ((1 << precision) - 1)
|
|
||||||
def pop(start, freq):
|
|
||||||
x = freq * (x_[0] >> precision) + cf - start, x_[1]
|
|
||||||
return ((x[0] << 32) | x[1][0], x[1][1]) if x[0] < rans_l else x
|
|
||||||
return cf, pop
|
|
||||||
|
|
||||||
def append_symbol(statfun, precision):
|
|
||||||
def append_(x, symbol):
|
|
||||||
start, freq = statfun(symbol)
|
|
||||||
return append(x, start, freq, precision)
|
|
||||||
return append_
|
|
||||||
|
|
||||||
def pop_symbol(statfun, precision):
|
|
||||||
def pop_(x):
|
|
||||||
cf, pop_fun = pop(x, precision)
|
|
||||||
symbol, (start, freq) = statfun(cf)
|
|
||||||
return pop_fun(start, freq), symbol
|
|
||||||
return pop_
|
|
||||||
|
|
||||||
def flatten(x):
|
|
||||||
"""Flatten a rans state x into a 1d numpy array."""
|
|
||||||
out, x = [x[0] >> 32, x[0]], x[1]
|
|
||||||
while x:
|
|
||||||
x_head, x = x
|
|
||||||
out.append(x_head)
|
|
||||||
return np.asarray(out, dtype=np.uint32)
|
|
||||||
|
|
||||||
def unflatten(arr):
|
|
||||||
"""Unflatten a 1d numpy array into a rans state."""
|
|
||||||
return (int(arr[0]) << 32 | int(arr[1]),
|
|
||||||
reduce(lambda tl, hd: (int(hd), tl), reversed(arr[2:]), ()))
|
|
||||||
|
|
@ -1,188 +0,0 @@
|
||||||
# !/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
import argparse
|
|
||||||
import torch
|
|
||||||
import torch.utils.data
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from utils.load_data import load_dataset
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')
|
|
||||||
|
|
||||||
parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'],
|
|
||||||
metavar='DATASET',
|
|
||||||
help='Dataset choice.')
|
|
||||||
|
|
||||||
parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
|
|
||||||
help='disables CUDA training')
|
|
||||||
|
|
||||||
parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')
|
|
||||||
|
|
||||||
parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
|
|
||||||
help='how many batches to wait before logging training status')
|
|
||||||
|
|
||||||
parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
|
|
||||||
help='Evaluate per how many epochs')
|
|
||||||
|
|
||||||
|
|
||||||
# optimization settings
|
|
||||||
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
|
|
||||||
help='number of epochs to train (default: 2000)')
|
|
||||||
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
|
|
||||||
help='number of early stopping epochs')
|
|
||||||
|
|
||||||
parser.add_argument('-bs', '--batch_size', type=int, default=10, metavar='BATCH_SIZE',
|
|
||||||
help='input batch size for training (default: 100)')
|
|
||||||
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
|
|
||||||
help='learning rate')
|
|
||||||
parser.add_argument('--warmup', type=int, default=10,
|
|
||||||
help='number of warmup epochs')
|
|
||||||
|
|
||||||
parser.add_argument('--data_augmentation_level', type=int, default=2,
|
|
||||||
help='data augmentation level')
|
|
||||||
|
|
||||||
parser.add_argument('--no_decode', action='store_true', default=False,
|
|
||||||
help='disables decoding')
|
|
||||||
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
|
||||||
|
|
||||||
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
|
|
||||||
|
|
||||||
|
|
||||||
def encode_images(img, model, decode):
|
|
||||||
batchsize, img_c, img_h, img_w = img.size()
|
|
||||||
c, h, w = model.args.input_size
|
|
||||||
|
|
||||||
assert img_h == img_w and h == w
|
|
||||||
|
|
||||||
if img_h != h:
|
|
||||||
assert img_h % h == 0
|
|
||||||
steps = img_h // h
|
|
||||||
|
|
||||||
states = [[] for i in range(batchsize)]
|
|
||||||
state_sizes = [0 for i in range(batchsize)]
|
|
||||||
bpd = [0 for i in range(batchsize)]
|
|
||||||
error = 0
|
|
||||||
|
|
||||||
for j in range(steps):
|
|
||||||
for i in range(steps):
|
|
||||||
r = encode_patches(
|
|
||||||
img[:, :, j*h:(j+1)*h, i*w:(i+1)*w], model, decode)
|
|
||||||
for b in range(batchsize):
|
|
||||||
|
|
||||||
if r[0][b] is None:
|
|
||||||
states[b].append(None)
|
|
||||||
else:
|
|
||||||
states[b].extend(r[0][b])
|
|
||||||
state_sizes[b] += r[1][b]
|
|
||||||
bpd[b] += r[2][b] / steps**2
|
|
||||||
error += r[3]
|
|
||||||
return states, state_sizes, bpd, error
|
|
||||||
else:
|
|
||||||
return encode_patches(img, model, decode)
|
|
||||||
|
|
||||||
|
|
||||||
def encode_patches(imgs, model, decode):
|
|
||||||
batchsize, img_c, img_h, img_w = imgs.size()
|
|
||||||
c, h, w = model.args.input_size
|
|
||||||
assert img_h == h and img_w == w
|
|
||||||
|
|
||||||
states = model.encode(imgs)
|
|
||||||
|
|
||||||
bpd = model.forward(imgs)[1].cpu().numpy()
|
|
||||||
|
|
||||||
state_sizes = []
|
|
||||||
error = 0
|
|
||||||
|
|
||||||
for b in range(batchsize):
|
|
||||||
if states[b] is None:
|
|
||||||
# Using escape bit ;)
|
|
||||||
state_sizes += [8 * img_c * img_h * img_w + 1]
|
|
||||||
|
|
||||||
# Error remains unchanged.
|
|
||||||
print('Escaping, not encoding.')
|
|
||||||
|
|
||||||
else:
|
|
||||||
if decode:
|
|
||||||
x_recon = model.decode([states[b]])
|
|
||||||
|
|
||||||
error += torch.sum(
|
|
||||||
torch.abs(x_recon.int() - imgs[b].int())).item()
|
|
||||||
|
|
||||||
# Append state plus an escape bit
|
|
||||||
state_sizes += [32 * len(states[b]) + 1]
|
|
||||||
|
|
||||||
return states, state_sizes, bpd, error
|
|
||||||
|
|
||||||
|
|
||||||
def run(args, kwargs):
|
|
||||||
torch.backends.cudnn.deterministic = True
|
|
||||||
torch.backends.cudnn.benchmark = False
|
|
||||||
|
|
||||||
args.snap_dir = snap_dir = \
|
|
||||||
'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'
|
|
||||||
|
|
||||||
# ==================================================================================================================
|
|
||||||
# SNAPSHOTS
|
|
||||||
# ==================================================================================================================
|
|
||||||
|
|
||||||
# ==================================================================================================================
|
|
||||||
# LOAD DATA
|
|
||||||
# ==================================================================================================================
|
|
||||||
train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)
|
|
||||||
|
|
||||||
final_model = torch.load(snap_dir + 'a.model')
|
|
||||||
if hasattr(final_model, 'module'):
|
|
||||||
final_model = final_model.module
|
|
||||||
final_model = final_model.cuda()
|
|
||||||
|
|
||||||
sizes = []
|
|
||||||
errors = []
|
|
||||||
bpds = []
|
|
||||||
|
|
||||||
import time
|
|
||||||
start = time.time()
|
|
||||||
|
|
||||||
t = 0
|
|
||||||
with torch.no_grad():
|
|
||||||
for data, _ in test_loader:
|
|
||||||
if args.cuda:
|
|
||||||
data = data.cuda()
|
|
||||||
|
|
||||||
state, state_sizes, bpd, error = \
|
|
||||||
encode_images(data, final_model, decode=not args.no_decode)
|
|
||||||
|
|
||||||
errors += [error]
|
|
||||||
bpds.extend(bpd)
|
|
||||||
sizes.extend(state_sizes)
|
|
||||||
|
|
||||||
t += len(data)
|
|
||||||
|
|
||||||
print(
|
|
||||||
'Examples: {}/{} bpd compression: {:.3f} error: {},'
|
|
||||||
' analytical bpd {:.3f}'.format(
|
|
||||||
t, len(test_loader.dataset),
|
|
||||||
np.mean(sizes) / np.prod(data.size()[1:]),
|
|
||||||
np.sum(errors),
|
|
||||||
np.mean(bpds)
|
|
||||||
))
|
|
||||||
|
|
||||||
if args.no_decode:
|
|
||||||
print('Not testing decoding.')
|
|
||||||
else:
|
|
||||||
print('Error: {}'.format(np.sum(errors)))
|
|
||||||
|
|
||||||
print('Took {:.3f} seconds / example'.format((time.time() - start) / t))
|
|
||||||
print('Final bpd: {:.3f} error: {}'.format(
|
|
||||||
np.mean(sizes) / np.prod(data.size()[1:]),
|
|
||||||
np.sum(errors)))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
run(args, kwargs)
|
|
||||||
|
|
@ -1,105 +0,0 @@
|
||||||
# !/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.utils.data
|
|
||||||
import torch.optim as optim
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from torchvision.utils import make_grid
|
|
||||||
|
|
||||||
import os
|
|
||||||
from optimization.training import train, evaluate
|
|
||||||
from utils.load_data import load_dataset
|
|
||||||
from utils.plotting import plot_training_curve
|
|
||||||
import imageio
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')
|
|
||||||
|
|
||||||
parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'],
|
|
||||||
metavar='DATASET',
|
|
||||||
help='Dataset choice.')
|
|
||||||
|
|
||||||
parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE',
|
|
||||||
help='input batch size for training (default: 100)')
|
|
||||||
|
|
||||||
parser.add_argument('--data_augmentation_level', type=int, default=2,
|
|
||||||
help='data augmentation level')
|
|
||||||
|
|
||||||
parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
|
|
||||||
help='disables CUDA training')
|
|
||||||
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
|
||||||
|
|
||||||
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
|
|
||||||
|
|
||||||
|
|
||||||
def run(args, kwargs):
|
|
||||||
|
|
||||||
args.snap_dir = snap_dir = \
|
|
||||||
'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'
|
|
||||||
|
|
||||||
# ==================================================================================================================
|
|
||||||
# SNAPSHOTS
|
|
||||||
# ==================================================================================================================
|
|
||||||
|
|
||||||
# ==================================================================================================================
|
|
||||||
# LOAD DATA
|
|
||||||
# ==================================================================================================================
|
|
||||||
train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)
|
|
||||||
|
|
||||||
final_model = torch.load(snap_dir + 'a.model')
|
|
||||||
if hasattr(final_model, 'module'):
|
|
||||||
final_model = final_model.module
|
|
||||||
|
|
||||||
from models.backround import SmoothRound
|
|
||||||
for module in final_model.modules():
|
|
||||||
if isinstance(module, SmoothRound):
|
|
||||||
module._round_decay = 1.
|
|
||||||
|
|
||||||
exp_dir = snap_dir + 'partials/'
|
|
||||||
os.makedirs(exp_dir, exist_ok=True)
|
|
||||||
|
|
||||||
images = []
|
|
||||||
with torch.no_grad():
|
|
||||||
for data, _ in test_loader:
|
|
||||||
|
|
||||||
if args.cuda:
|
|
||||||
data = data.cuda()
|
|
||||||
|
|
||||||
for i in range(len(data)):
|
|
||||||
_, _, _, pz, z, pys, ys, ldj = final_model.forward(data[i:i+1])
|
|
||||||
|
|
||||||
for j in range(len(ys) + 1):
|
|
||||||
x_recon = final_model.inverse(
|
|
||||||
z,
|
|
||||||
ys[len(ys) - j:])
|
|
||||||
|
|
||||||
images.append(x_recon.float())
|
|
||||||
|
|
||||||
if i == 10:
|
|
||||||
break
|
|
||||||
break
|
|
||||||
|
|
||||||
for j in range(len(ys) + 1):
|
|
||||||
|
|
||||||
grid = make_grid(
|
|
||||||
torch.stack(images[j::len(ys) + 1], dim=0).squeeze(),
|
|
||||||
nrow=11, padding=0,
|
|
||||||
normalize=True, range=None,
|
|
||||||
scale_each=False, pad_value=0)
|
|
||||||
|
|
||||||
imageio.imwrite(
|
|
||||||
exp_dir + 'loaded{j}.png'.format(j=j),
|
|
||||||
grid.cpu().numpy().transpose(1, 2, 0))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
run(args, kwargs)
|
|
||||||
|
|
@ -1,285 +0,0 @@
|
||||||
# !/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.utils.data
|
|
||||||
import torch.optim as optim
|
|
||||||
import numpy as np
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
from optimization.training import train, evaluate
|
|
||||||
from utils.load_data import load_dataset
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')
|
|
||||||
|
|
||||||
parser.add_argument('-d', '--dataset', type=str, default='cifar10',
|
|
||||||
choices=['cifar10', 'imagenet32', 'imagenet64'],
|
|
||||||
metavar='DATASET',
|
|
||||||
help='Dataset choice.')
|
|
||||||
|
|
||||||
parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
|
|
||||||
help='disables CUDA training')
|
|
||||||
|
|
||||||
parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')
|
|
||||||
|
|
||||||
parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
|
|
||||||
help='how many batches to wait before logging training status')
|
|
||||||
|
|
||||||
parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
|
|
||||||
help='Evaluate per how many epochs')
|
|
||||||
|
|
||||||
parser.add_argument('-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR',
|
|
||||||
help='output directory for model snapshots etc.')
|
|
||||||
|
|
||||||
fp = parser.add_mutually_exclusive_group(required=False)
|
|
||||||
fp.add_argument('-te', '--testing', action='store_true', dest='testing',
|
|
||||||
help='evaluate on test set after training')
|
|
||||||
fp.add_argument('-va', '--validation', action='store_false', dest='testing',
|
|
||||||
help='only evaluate on validation set')
|
|
||||||
parser.set_defaults(testing=True)
|
|
||||||
|
|
||||||
# optimization settings
|
|
||||||
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
|
|
||||||
help='number of epochs to train (default: 2000)')
|
|
||||||
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
|
|
||||||
help='number of early stopping epochs')
|
|
||||||
|
|
||||||
parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE',
|
|
||||||
help='input batch size for training (default: 100)')
|
|
||||||
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
|
|
||||||
help='learning rate')
|
|
||||||
parser.add_argument('--warmup', type=int, default=10,
|
|
||||||
help='number of warmup epochs')
|
|
||||||
|
|
||||||
parser.add_argument('--data_augmentation_level', type=int, default=2,
|
|
||||||
help='data augmentation level')
|
|
||||||
|
|
||||||
parser.add_argument('--variable_type', type=str, default='discrete',
|
|
||||||
help='variable type of data distribution: discrete/continuous',
|
|
||||||
choices=['discrete', 'continuous'])
|
|
||||||
parser.add_argument('--distribution_type', type=str, default='logistic',
|
|
||||||
choices=['logistic', 'normal', 'steplogistic'],
|
|
||||||
help='distribution type: logistic/normal')
|
|
||||||
parser.add_argument('--n_flows', type=int, default=8,
|
|
||||||
help='number of flows per level')
|
|
||||||
parser.add_argument('--n_levels', type=int, default=3,
|
|
||||||
help='number of levels')
|
|
||||||
|
|
||||||
parser.add_argument('--n_bits', type=int, default=8,
|
|
||||||
help='')
|
|
||||||
|
|
||||||
# ---------------- SETTINGS CONCERNING NETWORKS -------------
|
|
||||||
parser.add_argument('--densenet_depth', type=int, default=8,
|
|
||||||
help='Depth of densenets')
|
|
||||||
parser.add_argument('--n_channels', type=int, default=512,
|
|
||||||
help='number of channels in coupling and splitprior')
|
|
||||||
# ---------------- ----------------------------- -------------
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------- SETTINGS CONCERNING COUPLING LAYERS -------------
|
|
||||||
parser.add_argument('--coupling_type', type=str, default='shallow',
|
|
||||||
choices=['shallow', 'resnet', 'densenet'],
|
|
||||||
help='Type of coupling layer')
|
|
||||||
parser.add_argument('--splitfactor', default=0, type=int,
|
|
||||||
help='Split factor for coupling layers.')
|
|
||||||
|
|
||||||
parser.add_argument('--split_quarter', dest='split_quarter', action='store_true',
|
|
||||||
help='Split coupling layer on quarter')
|
|
||||||
parser.add_argument('--no_split_quarter', dest='split_quarter', action='store_false')
|
|
||||||
parser.set_defaults(split_quarter=True)
|
|
||||||
# ---------------- ----------------------------------- -------------
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------- SETTINGS CONCERNING SPLITPRIORS -------------
|
|
||||||
parser.add_argument('--splitprior_type', type=str, default='shallow',
|
|
||||||
choices=['none', 'shallow', 'resnet', 'densenet'],
|
|
||||||
help='Type of splitprior. Use \'none\' for no splitprior')
|
|
||||||
# ---------------- ------------------------------- -------------
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------- SETTINGS CONCERNING PRIORS -------------
|
|
||||||
parser.add_argument('--n_mixtures', type=int, default=1,
|
|
||||||
help='number of mixtures')
|
|
||||||
# ---------------- ------------------------------- -------------
|
|
||||||
|
|
||||||
parser.add_argument('--hard_round', dest='hard_round', action='store_true',
|
|
||||||
help='Rounding of translation in discrete models. Weird '
|
|
||||||
'probabilistic implications, only for experimental phase')
|
|
||||||
parser.add_argument('--no_hard_round', dest='hard_round', action='store_false')
|
|
||||||
parser.set_defaults(hard_round=True)
|
|
||||||
|
|
||||||
parser.add_argument('--round_approx', type=str, default='smooth',
|
|
||||||
choices=['smooth', 'stochastic'])
|
|
||||||
|
|
||||||
parser.add_argument('--lr_decay', default=0.999, type=float,
|
|
||||||
help='Learning rate')
|
|
||||||
|
|
||||||
parser.add_argument('--temperature', default=1.0, type=float,
|
|
||||||
help='Temperature used for BackRound. It is used in '
|
|
||||||
'the the SmoothRound module. '
|
|
||||||
'(default=1.0')
|
|
||||||
|
|
||||||
# gpu/cpu
|
|
||||||
parser.add_argument('--gpu_num', type=int, default=0, metavar='GPU',
|
|
||||||
help='choose GPU to run on.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
|
||||||
|
|
||||||
if args.manual_seed is None:
|
|
||||||
args.manual_seed = random.randint(1, 100000)
|
|
||||||
random.seed(args.manual_seed)
|
|
||||||
torch.manual_seed(args.manual_seed)
|
|
||||||
np.random.seed(args.manual_seed)
|
|
||||||
|
|
||||||
|
|
||||||
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
|
|
||||||
|
|
||||||
|
|
||||||
def run(args, kwargs):
|
|
||||||
|
|
||||||
print('\nMODEL SETTINGS: \n', args, '\n')
|
|
||||||
print("Random Seed: ", args.manual_seed)
|
|
||||||
|
|
||||||
if 'imagenet' in args.dataset and args.evaluate_interval_epochs > 5:
|
|
||||||
args.evaluate_interval_epochs = 5
|
|
||||||
|
|
||||||
# ==================================================================================================================
|
|
||||||
# SNAPSHOTS
|
|
||||||
# ==================================================================================================================
|
|
||||||
args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
|
|
||||||
args.model_signature = args.model_signature.replace(':', '_')
|
|
||||||
|
|
||||||
snapshots_path = os.path.join(args.out_dir, args.variable_type + '_' + args.distribution_type + args.dataset)
|
|
||||||
snap_dir = snapshots_path
|
|
||||||
|
|
||||||
snap_dir += '_' + 'flows_' + str(args.n_flows) + '_levels_' + str(args.n_levels)
|
|
||||||
|
|
||||||
snap_dir = snap_dir + '__' + args.model_signature + '/'
|
|
||||||
|
|
||||||
args.snap_dir = snap_dir
|
|
||||||
|
|
||||||
if not os.path.exists(snap_dir):
|
|
||||||
os.makedirs(snap_dir)
|
|
||||||
|
|
||||||
with open(snap_dir + 'log.txt', 'a') as ff:
|
|
||||||
print('\nMODEL SETTINGS: \n', args, '\n', file=ff)
|
|
||||||
|
|
||||||
# SAVING
|
|
||||||
torch.save(args, snap_dir + '.config')
|
|
||||||
|
|
||||||
# ==================================================================================================================
|
|
||||||
# LOAD DATA
|
|
||||||
# ==================================================================================================================
|
|
||||||
train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)
|
|
||||||
|
|
||||||
# ==================================================================================================================
|
|
||||||
# SELECT MODEL
|
|
||||||
# ==================================================================================================================
|
|
||||||
# flow parameters and architecture choice are passed on to model through args
|
|
||||||
print(args.input_size)
|
|
||||||
|
|
||||||
import models.Model as Model
|
|
||||||
|
|
||||||
model = Model.Model(args)
|
|
||||||
args.DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
||||||
model.set_temperature(args.temperature)
|
|
||||||
model.enable_hard_round(args.hard_round)
|
|
||||||
|
|
||||||
model_sample = model
|
|
||||||
|
|
||||||
# ====================================
|
|
||||||
# INIT
|
|
||||||
# ====================================
|
|
||||||
# data dependend initialization on CPU
|
|
||||||
for batch_idx, (data, _) in enumerate(train_loader):
|
|
||||||
model(data)
|
|
||||||
break
|
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
|
||||||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
|
||||||
model = torch.nn.DataParallel(model, dim=0)
|
|
||||||
|
|
||||||
model.to(args.DEVICE)
|
|
||||||
|
|
||||||
def lr_lambda(epoch):
|
|
||||||
return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)
|
|
||||||
optimizer = optim.Adamax(model.parameters(), lr=args.learning_rate, eps=1.e-7)
|
|
||||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
|
|
||||||
|
|
||||||
# ==================================================================================================================
|
|
||||||
# TRAINING
|
|
||||||
# ==================================================================================================================
|
|
||||||
train_bpd = []
|
|
||||||
val_bpd = []
|
|
||||||
|
|
||||||
# for early stopping
|
|
||||||
best_val_bpd = np.inf
|
|
||||||
best_train_bpd = np.inf
|
|
||||||
epoch = 0
|
|
||||||
|
|
||||||
train_times = []
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
model.train()
|
|
||||||
|
|
||||||
for epoch in range(1, args.epochs + 1):
|
|
||||||
t_start = time.time()
|
|
||||||
scheduler.step()
|
|
||||||
tr_loss, tr_bpd = train(epoch, train_loader, model, optimizer, args)
|
|
||||||
train_bpd.append(tr_bpd)
|
|
||||||
train_times.append(time.time()-t_start)
|
|
||||||
print('One training epoch took %.2f seconds' % (time.time()-t_start))
|
|
||||||
|
|
||||||
if epoch < 25 or epoch % args.evaluate_interval_epochs == 0:
|
|
||||||
v_loss, v_bpd = evaluate(
|
|
||||||
train_loader, val_loader, model, model_sample, args,
|
|
||||||
epoch=epoch, file=snap_dir + 'log.txt')
|
|
||||||
|
|
||||||
val_bpd.append(v_bpd)
|
|
||||||
|
|
||||||
# Model save based on TRAIN performance (is heavily correlated with validation performance.)
|
|
||||||
if np.mean(tr_bpd) < best_train_bpd:
|
|
||||||
best_train_bpd = np.mean(tr_bpd)
|
|
||||||
best_val_bpd = v_bpd
|
|
||||||
torch.save(model.module, snap_dir + 'a.model')
|
|
||||||
torch.save(optimizer, snap_dir + 'a.optimizer')
|
|
||||||
print('->model saved<-')
|
|
||||||
|
|
||||||
print('(BEST: train bpd {:.4f}, test bpd {:.4f})\n'.format(
|
|
||||||
best_train_bpd, best_val_bpd))
|
|
||||||
|
|
||||||
if math.isnan(v_loss):
|
|
||||||
raise ValueError('NaN encountered!')
|
|
||||||
|
|
||||||
train_bpd = np.hstack(train_bpd)
|
|
||||||
val_bpd = np.array(val_bpd)
|
|
||||||
|
|
||||||
# training time per epoch
|
|
||||||
train_times = np.array(train_times)
|
|
||||||
mean_train_time = np.mean(train_times)
|
|
||||||
std_train_time = np.std(train_times, ddof=1)
|
|
||||||
print('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time))
|
|
||||||
|
|
||||||
# ==================================================================================================================
|
|
||||||
# EVALUATION
|
|
||||||
# ==================================================================================================================
|
|
||||||
final_model = torch.load(snap_dir + 'a.model')
|
|
||||||
test_loss, test_bpd = evaluate(
|
|
||||||
train_loader, test_loader, final_model, final_model, args,
|
|
||||||
epoch=epoch, file=snap_dir + 'test_log.txt')
|
|
||||||
|
|
||||||
print('Test loss / bpd: %.2f / %.2f' % (test_loss, test_bpd))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
run(args, kwargs)
|
|
||||||
Binary file not shown.
Binary file not shown.
|
|
@ -1,191 +0,0 @@
|
||||||
import torch
|
|
||||||
import models.generative_flows as generative_flows
|
|
||||||
import numpy as np
|
|
||||||
from models.utils import Base
|
|
||||||
from .priors import Prior
|
|
||||||
from optimization.loss import compute_loss_array
|
|
||||||
from coding.coder import encode_sample, decode_sample
|
|
||||||
|
|
||||||
|
|
||||||
class Normalize(Base):
|
|
||||||
def __init__(self, args):
|
|
||||||
super().__init__()
|
|
||||||
self.n_bits = args.n_bits
|
|
||||||
self.variable_type = args.variable_type
|
|
||||||
self.input_size = args.input_size
|
|
||||||
|
|
||||||
def forward(self, x, ldj, reverse=False):
|
|
||||||
domain = 2.**self.n_bits
|
|
||||||
|
|
||||||
if self.variable_type == 'discrete':
|
|
||||||
# Discrete variables will be measured on intervals sized 1/domain.
|
|
||||||
# Hence, there is no need to change the log Jacobian determinant.
|
|
||||||
dldj = 0
|
|
||||||
elif self.variable_type == 'continuous':
|
|
||||||
dldj = -np.log(domain) * np.prod(self.input_size)
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
if not reverse:
|
|
||||||
x = (x - domain / 2) / domain
|
|
||||||
ldj += dldj
|
|
||||||
else:
|
|
||||||
x = x * domain + domain / 2
|
|
||||||
ldj -= dldj
|
|
||||||
|
|
||||||
return x, ldj
|
|
||||||
|
|
||||||
|
|
||||||
class Model(Base):
|
|
||||||
"""
|
|
||||||
The base VAE class containing gated convolutional encoder and decoder
|
|
||||||
architecture. Can be used as a base class for VAE's with normalizing flows.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, args):
|
|
||||||
super().__init__()
|
|
||||||
self.args = args
|
|
||||||
self.variable_type = args.variable_type
|
|
||||||
self.distribution_type = args.distribution_type
|
|
||||||
|
|
||||||
n_channels, height, width = args.input_size
|
|
||||||
|
|
||||||
self.normalize = Normalize(args)
|
|
||||||
|
|
||||||
self.flow = generative_flows.GenerativeFlow(
|
|
||||||
n_channels, height, width, args)
|
|
||||||
|
|
||||||
self.n_bits = args.n_bits
|
|
||||||
|
|
||||||
self.z_size = self.flow.z_size
|
|
||||||
|
|
||||||
self.prior = Prior(self.z_size, args)
|
|
||||||
|
|
||||||
def dequantize(self, x):
|
|
||||||
if self.training:
|
|
||||||
x = x + torch.rand_like(x)
|
|
||||||
else:
|
|
||||||
# Required for stability.
|
|
||||||
alpha = 1e-3
|
|
||||||
x = x + alpha + torch.rand_like(x) * (1 - 2 * alpha)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def loss(self, pz, z, pys, ys, ldj):
|
|
||||||
batchsize = z.size(0)
|
|
||||||
loss, bpd, bpd_per_prior = \
|
|
||||||
compute_loss_array(pz, z, pys, ys, ldj, self.args)
|
|
||||||
|
|
||||||
for module in self.modules():
|
|
||||||
if hasattr(module, 'auxillary_loss'):
|
|
||||||
loss += module.auxillary_loss() / batchsize
|
|
||||||
|
|
||||||
return loss, bpd, bpd_per_prior
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""
|
|
||||||
Evaluates the model as a whole, encodes and decodes. Note that the log
|
|
||||||
det jacobian is zero for a plain VAE (without flows), and z_0 = z_k.
|
|
||||||
"""
|
|
||||||
# Decode z to x.
|
|
||||||
|
|
||||||
assert x.dtype == torch.uint8
|
|
||||||
|
|
||||||
x = x.float()
|
|
||||||
|
|
||||||
ldj = torch.zeros_like(x[:, 0, 0, 0])
|
|
||||||
if self.variable_type == 'continuous':
|
|
||||||
x = self.dequantize(x)
|
|
||||||
elif self.variable_type == 'discrete':
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
x, ldj = self.normalize(x, ldj)
|
|
||||||
|
|
||||||
z, ldj, pys, ys = self.flow(x, ldj, pys=(), ys=())
|
|
||||||
|
|
||||||
pz, z, ldj = self.prior(z, ldj)
|
|
||||||
|
|
||||||
loss, bpd, bpd_per_prior = self.loss(pz, z, pys, ys, ldj)
|
|
||||||
|
|
||||||
return loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj
|
|
||||||
|
|
||||||
def inverse(self, z, ys):
|
|
||||||
ldj = torch.zeros_like(z[:, 0, 0, 0])
|
|
||||||
x, ldj, pys, py = \
|
|
||||||
self.flow(z, ldj, pys=[], ys=ys, reverse=True)
|
|
||||||
|
|
||||||
x, ldj = self.normalize(x, ldj, reverse=True)
|
|
||||||
|
|
||||||
x_uint8 = torch.clamp(x, min=0, max=255).to(
|
|
||||||
torch.uint8)
|
|
||||||
|
|
||||||
return x_uint8
|
|
||||||
|
|
||||||
def sample(self, n):
|
|
||||||
z_sample = self.prior.sample(n)
|
|
||||||
|
|
||||||
ldj = torch.zeros_like(z_sample[:, 0, 0, 0])
|
|
||||||
x_sample, ldj, pys, py = \
|
|
||||||
self.flow(z_sample, ldj, pys=[], ys=[], reverse=True)
|
|
||||||
|
|
||||||
x_sample, ldj = self.normalize(x_sample, ldj, reverse=True)
|
|
||||||
|
|
||||||
x_sample_uint8 = torch.clamp(x_sample, min=0, max=255).to(
|
|
||||||
torch.uint8)
|
|
||||||
|
|
||||||
return x_sample_uint8
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
batchsize = x.size(0)
|
|
||||||
_, _, _, pz, z, pys, ys, _ = self.forward(x)
|
|
||||||
|
|
||||||
pjs = list(pys) + [pz]
|
|
||||||
js = list(ys) + [z]
|
|
||||||
|
|
||||||
states = []
|
|
||||||
|
|
||||||
for b in range(batchsize):
|
|
||||||
state = None
|
|
||||||
for pj, j in zip(pjs, js):
|
|
||||||
pj_b = [param[b:b+1] for param in pj]
|
|
||||||
j_b = j[b:b+1]
|
|
||||||
|
|
||||||
state = encode_sample(
|
|
||||||
j_b, pj_b, self.variable_type,
|
|
||||||
self.distribution_type, state=state)
|
|
||||||
if state is None:
|
|
||||||
break
|
|
||||||
|
|
||||||
states.append(state)
|
|
||||||
|
|
||||||
return states
|
|
||||||
|
|
||||||
def decode(self, states):
|
|
||||||
def decode_fn(states, pj):
|
|
||||||
states = list(states)
|
|
||||||
j = []
|
|
||||||
|
|
||||||
for b in range(len(states)):
|
|
||||||
pj_b = [param[b:b+1] for param in pj]
|
|
||||||
|
|
||||||
states[b], j_b = decode_sample(
|
|
||||||
states[b], pj_b, self.variable_type,
|
|
||||||
self.distribution_type)
|
|
||||||
|
|
||||||
j.append(j_b)
|
|
||||||
|
|
||||||
j = torch.cat(j, dim=0)
|
|
||||||
return states, j
|
|
||||||
|
|
||||||
states, z = self.prior.decode(states, decode_fn=decode_fn)
|
|
||||||
|
|
||||||
ldj = torch.zeros_like(z[:, 0, 0, 0])
|
|
||||||
|
|
||||||
x, ldj = self.flow.decode(z, ldj, states, decode_fn=decode_fn)
|
|
||||||
|
|
||||||
x, ldj = self.normalize(x, ldj, reverse=True)
|
|
||||||
x = x.to(dtype=torch.uint8)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
@ -1,151 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from models.utils import Base
|
|
||||||
|
|
||||||
|
|
||||||
class RoundStraightThrough(torch.autograd.Function):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input):
|
|
||||||
rounded = torch.round(input, out=None)
|
|
||||||
return rounded
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
grad_input = grad_output.clone()
|
|
||||||
return grad_input
|
|
||||||
|
|
||||||
|
|
||||||
_round_straightthrough = RoundStraightThrough().apply
|
|
||||||
|
|
||||||
|
|
||||||
def _stacked_sigmoid(x, temperature, n_approx=3):
|
|
||||||
|
|
||||||
x_ = x - 0.5
|
|
||||||
rounded = torch.round(x_)
|
|
||||||
x_remainder = x_ - rounded
|
|
||||||
|
|
||||||
size = x_.size()
|
|
||||||
x_remainder = x_remainder.view(size + (1,))
|
|
||||||
|
|
||||||
translation = torch.arange(n_approx) - n_approx // 2
|
|
||||||
translation = translation.to(device=x.DEVICE, dtype=x.dtype)
|
|
||||||
translation = translation.view([1] * len(size) + [len(translation)])
|
|
||||||
out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)
|
|
||||||
|
|
||||||
return out + rounded - (n_approx // 2)
|
|
||||||
|
|
||||||
|
|
||||||
class SmoothRound(Base):
|
|
||||||
def __init__(self):
|
|
||||||
self._temperature = None
|
|
||||||
self._n_approx = None
|
|
||||||
super().__init__()
|
|
||||||
self.hard_round = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def temperature(self):
|
|
||||||
return self._temperature
|
|
||||||
|
|
||||||
@temperature.setter
|
|
||||||
def temperature(self, value):
|
|
||||||
self._temperature = value
|
|
||||||
|
|
||||||
if self._temperature <= 0.05:
|
|
||||||
self._n_approx = 1
|
|
||||||
elif 0.05 < self._temperature < 0.13:
|
|
||||||
self._n_approx = 3
|
|
||||||
else:
|
|
||||||
self._n_approx = 5
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
assert self._temperature is not None
|
|
||||||
assert self._n_approx is not None
|
|
||||||
assert self.hard_round is not None
|
|
||||||
|
|
||||||
if self.temperature <= 0.25:
|
|
||||||
h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx)
|
|
||||||
else:
|
|
||||||
h = x
|
|
||||||
|
|
||||||
if self.hard_round:
|
|
||||||
h = _round_straightthrough(h)
|
|
||||||
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
class StochasticRound(Base):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.hard_round = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
u = torch.rand_like(x)
|
|
||||||
|
|
||||||
h = x + u - 0.5
|
|
||||||
|
|
||||||
if self.hard_round:
|
|
||||||
h = _round_straightthrough(h)
|
|
||||||
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
class BackRound(Base):
|
|
||||||
|
|
||||||
def __init__(self, args, inverse_bin_width):
|
|
||||||
"""
|
|
||||||
BackRound is an approximation to Round that allows for Backpropagation.
|
|
||||||
|
|
||||||
Approximate the round function using a sum of translated sigmoids.
|
|
||||||
The temperature determines how well the round function is approximated,
|
|
||||||
i.e., a lower temperature corresponds to a better approximation, at
|
|
||||||
the cost of more vanishing gradients.
|
|
||||||
|
|
||||||
BackRound supports the following settings:
|
|
||||||
* By setting hard to True and temperature > 0.25, BackRound
|
|
||||||
reduces to a round function with a straight through gradient
|
|
||||||
estimator
|
|
||||||
* When using 0 < temperature <= 0.25 and hard = True, the
|
|
||||||
output in the forward pass is equivalent to a round function, but the
|
|
||||||
gradient is approximated by the gradient of a sum of sigmoids.
|
|
||||||
* When using hard = False, the output is not constrained to integers.
|
|
||||||
* When temperature > 0.25 and hard = False, BackRound reduces to
|
|
||||||
the identity function.
|
|
||||||
|
|
||||||
Arguments
|
|
||||||
---------
|
|
||||||
temperature: float
|
|
||||||
Temperature used for stacked sigmoid approximated. If temperature
|
|
||||||
is greater than 0.25, the approximation reduces to the indentiy
|
|
||||||
function.
|
|
||||||
hard: bool
|
|
||||||
If hard is True, a (hard) round is applied before returning. The
|
|
||||||
gradient for this is approximated using the straight-through
|
|
||||||
estimator.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.inverse_bin_width = inverse_bin_width
|
|
||||||
self.round_approx = args.round_approx
|
|
||||||
|
|
||||||
if args.round_approx == 'smooth':
|
|
||||||
self.round = SmoothRound()
|
|
||||||
elif args.round_approx == 'stochastic':
|
|
||||||
self.round = StochasticRound()
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.round_approx == 'smooth' or self.round_approx == 'stochastic':
|
|
||||||
h = x * self.inverse_bin_width
|
|
||||||
|
|
||||||
h = self.round(h)
|
|
||||||
|
|
||||||
return h / self.inverse_bin_width
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
|
|
@ -1,142 +0,0 @@
|
||||||
"""
|
|
||||||
Collection of flow strategies
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from models.utils import Base
|
|
||||||
from .backround import BackRound
|
|
||||||
from .networks import NN
|
|
||||||
|
|
||||||
|
|
||||||
UNIT_TESTING = False
|
|
||||||
|
|
||||||
|
|
||||||
class SplitFactorCoupling(Base):
|
|
||||||
def __init__(self, c_in, factor, height, width, args):
|
|
||||||
super().__init__()
|
|
||||||
self.n_channels = args.n_channels
|
|
||||||
self.kernel = 3
|
|
||||||
self.input_channel = c_in
|
|
||||||
self.round_approx = args.round_approx
|
|
||||||
|
|
||||||
if args.variable_type == 'discrete':
|
|
||||||
self.round = BackRound(
|
|
||||||
args, inverse_bin_width=2**args.n_bits)
|
|
||||||
else:
|
|
||||||
self.round = None
|
|
||||||
|
|
||||||
self.split_idx = c_in - (c_in // factor)
|
|
||||||
|
|
||||||
self.nn = NN(
|
|
||||||
args=args,
|
|
||||||
c_in=self.split_idx,
|
|
||||||
c_out=c_in - self.split_idx,
|
|
||||||
height=height,
|
|
||||||
width=width,
|
|
||||||
kernel=self.kernel,
|
|
||||||
nn_type=args.coupling_type)
|
|
||||||
|
|
||||||
def forward(self, z, ldj, reverse=False):
|
|
||||||
z1 = z[:, :self.split_idx, :, :]
|
|
||||||
z2 = z[:, self.split_idx:, :, :]
|
|
||||||
|
|
||||||
t = self.nn(z1)
|
|
||||||
|
|
||||||
if self.round is not None:
|
|
||||||
t = self.round(t)
|
|
||||||
|
|
||||||
if not reverse:
|
|
||||||
z2 = z2 + t
|
|
||||||
else:
|
|
||||||
z2 = z2 - t
|
|
||||||
|
|
||||||
z = torch.cat([z1, z2], dim=1)
|
|
||||||
|
|
||||||
return z, ldj
|
|
||||||
|
|
||||||
|
|
||||||
class Coupling(Base):
|
|
||||||
def __init__(self, c_in, height, width, args):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if args.split_quarter:
|
|
||||||
factor = 4
|
|
||||||
elif args.splitfactor > 1:
|
|
||||||
factor = args.splitfactor
|
|
||||||
else:
|
|
||||||
factor = 2
|
|
||||||
|
|
||||||
self.coupling = SplitFactorCoupling(
|
|
||||||
c_in, factor, height, width, args=args)
|
|
||||||
|
|
||||||
def forward(self, z, ldj, reverse=False):
|
|
||||||
return self.coupling(z, ldj, reverse)
|
|
||||||
|
|
||||||
|
|
||||||
def test_generative_flow():
|
|
||||||
import models.networks as networks
|
|
||||||
global UNIT_TESTING
|
|
||||||
|
|
||||||
networks.UNIT_TESTING = True
|
|
||||||
UNIT_TESTING = True
|
|
||||||
|
|
||||||
batch_size = 17
|
|
||||||
|
|
||||||
input_size = [12, 16, 16]
|
|
||||||
|
|
||||||
class Args():
|
|
||||||
def __init__(self):
|
|
||||||
self.input_size = input_size
|
|
||||||
self.learn_split = False
|
|
||||||
self.variable_type = 'continuous'
|
|
||||||
self.distribution_type = 'logistic'
|
|
||||||
self.round_approx = 'smooth'
|
|
||||||
self.coupling_type = 'shallow'
|
|
||||||
self.conv_type = 'standard'
|
|
||||||
self.densenet_depth = 8
|
|
||||||
self.bottleneck = False
|
|
||||||
self.n_channels = 512
|
|
||||||
self.network1x1 = 'standard'
|
|
||||||
self.auxilary_freq = -1
|
|
||||||
self.actnorm = False
|
|
||||||
self.LU = False
|
|
||||||
self.coupling_lifting_L = True
|
|
||||||
self.splitprior = True
|
|
||||||
self.split_quarter = True
|
|
||||||
self.n_levels = 2
|
|
||||||
self.n_flows = 2
|
|
||||||
self.cond_L = True
|
|
||||||
self.n_bits = True
|
|
||||||
|
|
||||||
args = Args()
|
|
||||||
|
|
||||||
x = (torch.randint(256, size=[batch_size] + input_size).float() - 128.) / 256.
|
|
||||||
ldj = torch.zeros_like(x[:, 0, 0, 0])
|
|
||||||
|
|
||||||
model = Coupling(c_in=12, height=16, width=16, args=args)
|
|
||||||
|
|
||||||
print(model)
|
|
||||||
|
|
||||||
model.set_temperature(1.)
|
|
||||||
model.enable_hard_round()
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
z, ldj = model(x, ldj, reverse=False)
|
|
||||||
|
|
||||||
# Check if gradient computation works
|
|
||||||
loss = torch.sum(z**2)
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
recon, ldj = model(z, ldj, reverse=True)
|
|
||||||
|
|
||||||
sse = torch.sum(torch.pow(x - recon, 2)).item()
|
|
||||||
ae = torch.abs(x - recon).sum()
|
|
||||||
print('Error in recon: sse {} ae {}'.format(sse / np.prod(input_size), ae))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_generative_flow()
|
|
||||||
|
|
@ -1,175 +0,0 @@
|
||||||
"""
|
|
||||||
Collection of flow strategies
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from models.utils import Base
|
|
||||||
from .priors import SplitPrior
|
|
||||||
from .coupling import Coupling
|
|
||||||
|
|
||||||
|
|
||||||
UNIT_TESTING = False
|
|
||||||
|
|
||||||
|
|
||||||
def space_to_depth(x):
|
|
||||||
xs = x.size()
|
|
||||||
# Pick off every second element
|
|
||||||
x = x.view(xs[0], xs[1], xs[2] // 2, 2, xs[3] // 2, 2)
|
|
||||||
# Transpose picked elements next to channels.
|
|
||||||
x = x.permute((0, 1, 3, 5, 2, 4)).contiguous()
|
|
||||||
# Combine with channels.
|
|
||||||
x = x.view(xs[0], xs[1] * 4, xs[2] // 2, xs[3] // 2)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def depth_to_space(x):
|
|
||||||
xs = x.size()
|
|
||||||
# Pick off elements from channels
|
|
||||||
x = x.view(xs[0], xs[1] // 4, 2, 2, xs[2], xs[3])
|
|
||||||
# Transpose picked elements next to HW dimensions.
|
|
||||||
x = x.permute((0, 1, 4, 2, 5, 3)).contiguous()
|
|
||||||
# Combine with HW dimensions.
|
|
||||||
x = x.view(xs[0], xs[1] // 4, xs[2] * 2, xs[3] * 2)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def int_shape(x):
|
|
||||||
return list(map(int, x.size()))
|
|
||||||
|
|
||||||
|
|
||||||
class Flatten(Base):
|
|
||||||
def forward(self, x):
|
|
||||||
return x.view(x.size(0), -1)
|
|
||||||
|
|
||||||
|
|
||||||
class Reshape(Base):
|
|
||||||
def __init__(self, shape):
|
|
||||||
super().__init__()
|
|
||||||
self.shape = shape
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x.view(x.size(0), *self.shape)
|
|
||||||
|
|
||||||
|
|
||||||
class Reverse(Base):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, z, reverse=False):
|
|
||||||
flip_idx = torch.arange(z.size(1) - 1, -1, -1).long()
|
|
||||||
z = z[:, flip_idx, :, :]
|
|
||||||
return z
|
|
||||||
|
|
||||||
|
|
||||||
class Permute(Base):
|
|
||||||
def __init__(self, n_channels):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
permutation = np.arange(n_channels, dtype='int')
|
|
||||||
np.random.shuffle(permutation)
|
|
||||||
|
|
||||||
permutation_inv = np.zeros(n_channels, dtype='int')
|
|
||||||
permutation_inv[permutation] = np.arange(n_channels, dtype='int')
|
|
||||||
|
|
||||||
self.permutation = torch.from_numpy(permutation)
|
|
||||||
self.permutation_inv = torch.from_numpy(permutation_inv)
|
|
||||||
|
|
||||||
def forward(self, z, ldj, reverse=False):
|
|
||||||
if not reverse:
|
|
||||||
z = z[:, self.permutation, :, :]
|
|
||||||
else:
|
|
||||||
z = z[:, self.permutation_inv, :, :]
|
|
||||||
|
|
||||||
return z, ldj
|
|
||||||
|
|
||||||
def InversePermute(self):
|
|
||||||
inv_permute = Permute(len(self.permutation))
|
|
||||||
inv_permute.permutation = self.permutation_inv
|
|
||||||
inv_permute.permutation_inv = self.permutation
|
|
||||||
return inv_permute
|
|
||||||
|
|
||||||
|
|
||||||
class Squeeze(Base):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, z, ldj, reverse=False):
|
|
||||||
if not reverse:
|
|
||||||
z = space_to_depth(z)
|
|
||||||
else:
|
|
||||||
z = depth_to_space(z)
|
|
||||||
return z, ldj
|
|
||||||
|
|
||||||
|
|
||||||
class GenerativeFlow(Base):
|
|
||||||
def __init__(self, n_channels, height, width, args):
|
|
||||||
super().__init__()
|
|
||||||
layers = []
|
|
||||||
layers.append(Squeeze())
|
|
||||||
n_channels *= 4
|
|
||||||
height //= 2
|
|
||||||
width //= 2
|
|
||||||
|
|
||||||
for level in range(args.n_levels):
|
|
||||||
|
|
||||||
for i in range(args.n_flows):
|
|
||||||
perm_layer = Permute(n_channels)
|
|
||||||
layers.append(perm_layer)
|
|
||||||
|
|
||||||
layers.append(
|
|
||||||
Coupling(n_channels, height, width, args))
|
|
||||||
|
|
||||||
if level < args.n_levels - 1:
|
|
||||||
if args.splitprior_type != 'none':
|
|
||||||
# Standard splitprior
|
|
||||||
factor_out = n_channels // 2
|
|
||||||
layers.append(SplitPrior(n_channels, factor_out, height, width, args))
|
|
||||||
n_channels = n_channels - factor_out
|
|
||||||
|
|
||||||
layers.append(Squeeze())
|
|
||||||
n_channels *= 4
|
|
||||||
height //= 2
|
|
||||||
width //= 2
|
|
||||||
|
|
||||||
self.layers = torch.nn.ModuleList(layers)
|
|
||||||
self.z_size = (n_channels, height, width)
|
|
||||||
|
|
||||||
def forward(self, z, ldj, pys=(), ys=(), reverse=False):
|
|
||||||
if not reverse:
|
|
||||||
for l, layer in enumerate(self.layers):
|
|
||||||
if isinstance(layer, (SplitPrior)):
|
|
||||||
py, y, z, ldj = layer(z, ldj)
|
|
||||||
pys += (py,)
|
|
||||||
ys += (y,)
|
|
||||||
|
|
||||||
else:
|
|
||||||
z, ldj = layer(z, ldj)
|
|
||||||
|
|
||||||
else:
|
|
||||||
for l, layer in reversed(list(enumerate(self.layers))):
|
|
||||||
if isinstance(layer, (SplitPrior)):
|
|
||||||
if len(ys) > 0:
|
|
||||||
z, ldj = layer.inverse(z, ldj, y=ys[-1])
|
|
||||||
# Pop last element
|
|
||||||
ys = ys[:-1]
|
|
||||||
else:
|
|
||||||
z, ldj = layer.inverse(z, ldj, y=None)
|
|
||||||
|
|
||||||
else:
|
|
||||||
z, ldj = layer(z, ldj, reverse=True)
|
|
||||||
|
|
||||||
return z, ldj, pys, ys
|
|
||||||
|
|
||||||
def decode(self, z, ldj, state, decode_fn):
|
|
||||||
|
|
||||||
for l, layer in reversed(list(enumerate(self.layers))):
|
|
||||||
if isinstance(layer, SplitPrior):
|
|
||||||
z, ldj, state = layer.decode(z, ldj, state, decode_fn)
|
|
||||||
|
|
||||||
else:
|
|
||||||
z, ldj = layer(z, ldj, reverse=True)
|
|
||||||
|
|
||||||
return z, ldj
|
|
||||||
|
|
@ -1,154 +0,0 @@
|
||||||
"""
|
|
||||||
Collection of flow strategies
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from models.utils import Base
|
|
||||||
|
|
||||||
|
|
||||||
UNIT_TESTING = False
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2dReLU(Base):
|
|
||||||
def __init__(
|
|
||||||
self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0,
|
|
||||||
bias=True):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.nn = nn.Conv2d(n_inputs, n_outputs, kernel_size, padding=padding)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h = self.nn(x)
|
|
||||||
|
|
||||||
y = F.relu(h)
|
|
||||||
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(Base):
|
|
||||||
def __init__(self, n_channels, kernel, Conv2dAct):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.nn = torch.nn.Sequential(
|
|
||||||
Conv2dAct(n_channels, n_channels, kernel, padding=1),
|
|
||||||
torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h = self.nn(x)
|
|
||||||
h = F.relu(h + x)
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
class DenseLayer(Base):
|
|
||||||
def __init__(self, args, n_inputs, growth, Conv2dAct):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
conv1x1 = Conv2dAct(
|
|
||||||
n_inputs, n_inputs, kernel_size=1, stride=1,
|
|
||||||
padding=0, bias=True)
|
|
||||||
|
|
||||||
self.nn = torch.nn.Sequential(
|
|
||||||
conv1x1,
|
|
||||||
Conv2dAct(
|
|
||||||
n_inputs, growth, kernel_size=3, stride=1,
|
|
||||||
padding=1, bias=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h = self.nn(x)
|
|
||||||
|
|
||||||
h = torch.cat([x, h], dim=1)
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
class DenseBlock(Base):
|
|
||||||
def __init__(
|
|
||||||
self, args, n_inputs, n_outputs, kernel, Conv2dAct):
|
|
||||||
super().__init__()
|
|
||||||
depth = args.densenet_depth
|
|
||||||
|
|
||||||
future_growth = n_outputs - n_inputs
|
|
||||||
|
|
||||||
layers = []
|
|
||||||
|
|
||||||
for d in range(depth):
|
|
||||||
growth = future_growth // (depth - d)
|
|
||||||
|
|
||||||
layers.append(DenseLayer(args, n_inputs, growth, Conv2dAct))
|
|
||||||
n_inputs += growth
|
|
||||||
future_growth -= growth
|
|
||||||
|
|
||||||
self.nn = torch.nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.nn(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Identity(Base):
|
|
||||||
def __init__(self):
|
|
||||||
super.__init__()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class NN(Base):
|
|
||||||
def __init__(
|
|
||||||
self, args, c_in, c_out, height, width, nn_type, kernel=3):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
Conv2dAct = Conv2dReLU
|
|
||||||
n_channels = args.n_channels
|
|
||||||
|
|
||||||
if nn_type == 'shallow':
|
|
||||||
|
|
||||||
if args.network1x1 == 'standard':
|
|
||||||
conv1x1 = Conv2dAct(
|
|
||||||
n_channels, n_channels, kernel_size=1, stride=1,
|
|
||||||
padding=0, bias=False)
|
|
||||||
|
|
||||||
layers = [
|
|
||||||
Conv2dAct(c_in, n_channels, kernel, padding=1),
|
|
||||||
conv1x1]
|
|
||||||
|
|
||||||
layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)]
|
|
||||||
|
|
||||||
elif nn_type == 'resnet':
|
|
||||||
layers = [
|
|
||||||
Conv2dAct(c_in, n_channels, kernel, padding=1),
|
|
||||||
ResidualBlock(n_channels, kernel, Conv2dAct),
|
|
||||||
ResidualBlock(n_channels, kernel, Conv2dAct)]
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)
|
|
||||||
]
|
|
||||||
|
|
||||||
elif nn_type == 'densenet':
|
|
||||||
layers = [
|
|
||||||
DenseBlock(
|
|
||||||
args=args,
|
|
||||||
n_inputs=c_in,
|
|
||||||
n_outputs=n_channels + c_in,
|
|
||||||
kernel=kernel,
|
|
||||||
Conv2dAct=Conv2dAct)]
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
self.nn = torch.nn.Sequential(*layers)
|
|
||||||
|
|
||||||
# Set parameters of last conv-layer to zero.
|
|
||||||
if not UNIT_TESTING:
|
|
||||||
self.nn[-1].weight.data.zero_()
|
|
||||||
self.nn[-1].bias.data.zero_()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.nn(x)
|
|
||||||
|
|
@ -1,164 +0,0 @@
|
||||||
"""
|
|
||||||
Collection of flow strategies
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn import Parameter
|
|
||||||
from utils.distributions import sample_discretized_logistic, \
|
|
||||||
sample_mixture_discretized_logistic, sample_normal, sample_logistic, \
|
|
||||||
sample_discretized_normal, sample_mixture_normal
|
|
||||||
from models.utils import Base
|
|
||||||
from .networks import NN
|
|
||||||
|
|
||||||
|
|
||||||
def sample_prior(px, variable_type, distribution_type, inverse_bin_width):
|
|
||||||
if variable_type == 'discrete':
|
|
||||||
if distribution_type == 'logistic':
|
|
||||||
if len(px) == 2:
|
|
||||||
return sample_discretized_logistic(
|
|
||||||
*px, inverse_bin_width=inverse_bin_width)
|
|
||||||
elif len(px) == 3:
|
|
||||||
return sample_mixture_discretized_logistic(
|
|
||||||
*px, inverse_bin_width=inverse_bin_width)
|
|
||||||
|
|
||||||
elif distribution_type == 'normal':
|
|
||||||
return sample_discretized_normal(
|
|
||||||
*px, inverse_bin_width=inverse_bin_width)
|
|
||||||
|
|
||||||
elif variable_type == 'continuous':
|
|
||||||
if distribution_type == 'logistic':
|
|
||||||
return sample_logistic(*px)
|
|
||||||
elif distribution_type == 'normal':
|
|
||||||
if len(px) == 2:
|
|
||||||
return sample_normal(*px)
|
|
||||||
elif len(px) == 3:
|
|
||||||
return sample_mixture_normal(*px)
|
|
||||||
elif distribution_type == 'steplogistic':
|
|
||||||
return sample_logistic(*px)
|
|
||||||
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
|
|
||||||
class Prior(Base):
|
|
||||||
def __init__(self, size, args):
|
|
||||||
super().__init__()
|
|
||||||
c, h, w = size
|
|
||||||
|
|
||||||
self.inverse_bin_width = 2**args.n_bits
|
|
||||||
self.variable_type = args.variable_type
|
|
||||||
self.distribution_type = args.distribution_type
|
|
||||||
self.n_mixtures = args.n_mixtures
|
|
||||||
|
|
||||||
if self.n_mixtures == 1:
|
|
||||||
self.mu = Parameter(torch.Tensor(c, h, w))
|
|
||||||
self.logs = Parameter(torch.Tensor(c, h, w))
|
|
||||||
elif self.n_mixtures > 1:
|
|
||||||
self.mu = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
|
||||||
self.logs = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
|
||||||
self.pi_logit = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
|
||||||
|
|
||||||
self.reset_parameters()
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
self.mu.data.zero_()
|
|
||||||
|
|
||||||
if self.n_mixtures > 1:
|
|
||||||
self.pi_logit.data.zero_()
|
|
||||||
for i in range(self.n_mixtures):
|
|
||||||
self.mu.data[..., i] += i - (self.n_mixtures - 1) / 2.
|
|
||||||
|
|
||||||
self.logs.data.zero_()
|
|
||||||
|
|
||||||
def get_pz(self, n):
|
|
||||||
if self.n_mixtures == 1:
|
|
||||||
mu = self.mu.repeat(n, 1, 1, 1)
|
|
||||||
logs = self.logs.repeat(n, 1, 1, 1) # scaling scale
|
|
||||||
return mu, logs
|
|
||||||
|
|
||||||
elif self.n_mixtures > 1:
|
|
||||||
pi = F.softmax(self.pi_logit, dim=-1)
|
|
||||||
mu = self.mu.repeat(n, 1, 1, 1, 1)
|
|
||||||
logs = self.logs.repeat(n, 1, 1, 1, 1)
|
|
||||||
pi = pi.repeat(n, 1, 1, 1, 1)
|
|
||||||
return mu, logs, pi
|
|
||||||
|
|
||||||
def forward(self, z, ldj):
|
|
||||||
pz = self.get_pz(z.size(0))
|
|
||||||
|
|
||||||
return pz, z, ldj
|
|
||||||
|
|
||||||
def sample(self, n):
|
|
||||||
pz = self.get_pz(n)
|
|
||||||
|
|
||||||
z_sample = sample_prior(pz, self.variable_type, self.distribution_type, self.inverse_bin_width)
|
|
||||||
|
|
||||||
return z_sample
|
|
||||||
|
|
||||||
def decode(self, states, decode_fn):
|
|
||||||
pz = self.get_pz(n=len(states))
|
|
||||||
|
|
||||||
states, z = decode_fn(states, pz)
|
|
||||||
return states, z
|
|
||||||
|
|
||||||
|
|
||||||
class SplitPrior(Base):
|
|
||||||
def __init__(self, c_in, factor_out, height, width, args):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.split_idx = c_in - factor_out
|
|
||||||
self.inverse_bin_width = 2**args.n_bits
|
|
||||||
self.variable_type = args.variable_type
|
|
||||||
self.distribution_type = args.distribution_type
|
|
||||||
self.input_channel = c_in
|
|
||||||
|
|
||||||
self.nn = NN(
|
|
||||||
args=args,
|
|
||||||
c_in=c_in - factor_out,
|
|
||||||
c_out=factor_out * 2,
|
|
||||||
height=height,
|
|
||||||
width=width,
|
|
||||||
nn_type=args.splitprior_type)
|
|
||||||
|
|
||||||
def get_py(self, z):
|
|
||||||
h = self.nn(z)
|
|
||||||
mu = h[:, ::2, :, :]
|
|
||||||
logs = h[:, 1::2, :, :]
|
|
||||||
|
|
||||||
py = [mu, logs]
|
|
||||||
|
|
||||||
return py
|
|
||||||
|
|
||||||
def split(self, z):
|
|
||||||
z1 = z[:, :self.split_idx, :, :]
|
|
||||||
y = z[:, self.split_idx:, :, :]
|
|
||||||
return z1, y
|
|
||||||
|
|
||||||
def combine(self, z, y):
|
|
||||||
result = torch.cat([z, y], dim=1)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def forward(self, z, ldj):
|
|
||||||
z, y = self.split(z)
|
|
||||||
|
|
||||||
py = self.get_py(z)
|
|
||||||
|
|
||||||
return py, y, z, ldj
|
|
||||||
|
|
||||||
def inverse(self, z, ldj, y):
|
|
||||||
# Sample if y is not given.
|
|
||||||
if y is None:
|
|
||||||
py = self.get_py(z)
|
|
||||||
y = sample_prior(py, self.variable_type, self.distribution_type, self.inverse_bin_width)
|
|
||||||
|
|
||||||
z = self.combine(z, y)
|
|
||||||
|
|
||||||
return z, ldj
|
|
||||||
|
|
||||||
def decode(self, z, ldj, states, decode_fn):
|
|
||||||
py = self.get_py(z)
|
|
||||||
states, y = decode_fn(states, py)
|
|
||||||
return self.combine(z, y), ldj, states
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class Base(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
The base class for modules. That contains a disable round mode
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def _set_child_attribute(self, attr, value):
|
|
||||||
r"""Sets the module in rounding mode.
|
|
||||||
|
|
||||||
This has any effect only on certain modules if variable type is
|
|
||||||
discrete.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Module: self
|
|
||||||
"""
|
|
||||||
if hasattr(self, attr):
|
|
||||||
setattr(self, attr, value)
|
|
||||||
|
|
||||||
for module in self.modules():
|
|
||||||
if hasattr(module, attr):
|
|
||||||
setattr(module, attr, value)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def set_temperature(self, value):
|
|
||||||
self._set_child_attribute("temperature", value)
|
|
||||||
|
|
||||||
def enable_hard_round(self, mode=True):
|
|
||||||
self._set_child_attribute("hard_round", mode)
|
|
||||||
|
|
||||||
def disable_hard_round(self, mode=True):
|
|
||||||
self.enable_hard_round(not mode)
|
|
||||||
|
|
@ -1,148 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from utils.distributions import log_discretized_logistic, \
|
|
||||||
log_mixture_discretized_logistic, log_normal, log_discretized_normal, \
|
|
||||||
log_logistic, log_mixture_normal
|
|
||||||
from models.backround import _round_straightthrough
|
|
||||||
|
|
||||||
|
|
||||||
def compute_log_ps(pxs, xs, args):
|
|
||||||
# Add likelihoods of intermediate representations.
|
|
||||||
inverse_bin_width = 2.**args.n_bits
|
|
||||||
|
|
||||||
log_pxs = []
|
|
||||||
for px, x in zip(pxs, xs):
|
|
||||||
|
|
||||||
if args.variable_type == 'discrete':
|
|
||||||
if args.distribution_type == 'logistic':
|
|
||||||
log_px = log_discretized_logistic(
|
|
||||||
x, *px, inverse_bin_width=inverse_bin_width)
|
|
||||||
elif args.distribution_type == 'normal':
|
|
||||||
log_px = log_discretized_normal(
|
|
||||||
x, *px, inverse_bin_width=inverse_bin_width)
|
|
||||||
elif args.variable_type == 'continuous':
|
|
||||||
if args.distribution_type == 'logistic':
|
|
||||||
log_px = log_logistic(x, *px)
|
|
||||||
elif args.distribution_type == 'normal':
|
|
||||||
log_px = log_normal(x, *px)
|
|
||||||
elif args.distribution_type == 'steplogistic':
|
|
||||||
x = _round_straightthrough(x * inverse_bin_width) / inverse_bin_width
|
|
||||||
log_px = log_discretized_logistic(
|
|
||||||
x, *px, inverse_bin_width=inverse_bin_width)
|
|
||||||
|
|
||||||
log_pxs.append(
|
|
||||||
torch.sum(log_px, dim=[1, 2, 3]))
|
|
||||||
|
|
||||||
return log_pxs
|
|
||||||
|
|
||||||
|
|
||||||
def compute_log_pz(pz, z, args):
|
|
||||||
inverse_bin_width = 2.**args.n_bits
|
|
||||||
|
|
||||||
if args.variable_type == 'discrete':
|
|
||||||
if args.distribution_type == 'logistic':
|
|
||||||
if args.n_mixtures == 1:
|
|
||||||
log_pz = log_discretized_logistic(
|
|
||||||
z, pz[0], pz[1], inverse_bin_width=inverse_bin_width)
|
|
||||||
else:
|
|
||||||
log_pz = log_mixture_discretized_logistic(
|
|
||||||
z, pz[0], pz[1], pz[2],
|
|
||||||
inverse_bin_width=inverse_bin_width)
|
|
||||||
elif args.distribution_type == 'normal':
|
|
||||||
log_pz = log_discretized_normal(
|
|
||||||
z, *pz, inverse_bin_width=inverse_bin_width)
|
|
||||||
|
|
||||||
elif args.variable_type == 'continuous':
|
|
||||||
if args.distribution_type == 'logistic':
|
|
||||||
log_pz = log_logistic(z, *pz)
|
|
||||||
elif args.distribution_type == 'normal':
|
|
||||||
if args.n_mixtures == 1:
|
|
||||||
log_pz = log_normal(z, *pz)
|
|
||||||
else:
|
|
||||||
log_pz = log_mixture_normal(z, *pz)
|
|
||||||
elif args.distribution_type == 'steplogistic':
|
|
||||||
z = _round_straightthrough(z * 256.) / 256.
|
|
||||||
log_pz = log_discretized_logistic(z, *pz)
|
|
||||||
|
|
||||||
log_pz = torch.sum(
|
|
||||||
log_pz,
|
|
||||||
dim=[1, 2, 3])
|
|
||||||
|
|
||||||
return log_pz
|
|
||||||
|
|
||||||
|
|
||||||
def compute_loss_function(pz, z, pys, ys, ldj, args):
|
|
||||||
"""
|
|
||||||
Computes the cross entropy loss function while summing over batch dimension, not averaged!
|
|
||||||
:param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
|
|
||||||
:param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
|
|
||||||
:param z_mu: mean of z_0
|
|
||||||
:param z_var: variance of z_0
|
|
||||||
:param z_0: first stochastic latent variable
|
|
||||||
:param z_k: last stochastic latent variable
|
|
||||||
:param ldj: log det jacobian
|
|
||||||
:param args: global parameter settings
|
|
||||||
:param beta: beta for kl loss
|
|
||||||
:return: loss, ce, kl
|
|
||||||
"""
|
|
||||||
batch_size = z.size(0)
|
|
||||||
|
|
||||||
# Get array loss, sum over batch
|
|
||||||
loss_array, bpd_array, bpd_per_prior_array = \
|
|
||||||
compute_loss_array(pz, z, pys, ys, ldj, args)
|
|
||||||
|
|
||||||
loss = torch.mean(loss_array)
|
|
||||||
bpd = torch.mean(bpd_array).item()
|
|
||||||
bpd_per_prior = [torch.mean(x) for x in bpd_per_prior_array]
|
|
||||||
|
|
||||||
return loss, bpd, bpd_per_prior
|
|
||||||
|
|
||||||
|
|
||||||
def convert_bpd(log_p, input_size):
|
|
||||||
return -log_p / (np.prod(input_size) * np.log(2.))
|
|
||||||
|
|
||||||
|
|
||||||
def compute_loss_array(pz, z, pys, ys, ldj, args):
|
|
||||||
"""
|
|
||||||
Computes the cross entropy loss function while summing over batch dimension, not averaged!
|
|
||||||
:param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
|
|
||||||
:param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
|
|
||||||
:param z_mu: mean of z_0
|
|
||||||
:param z_var: variance of z_0
|
|
||||||
:param z_0: first stochastic latent variable
|
|
||||||
:param z_k: last stochastic latent variable
|
|
||||||
:param ldj: log det jacobian
|
|
||||||
:param args: global parameter settings
|
|
||||||
:param beta: beta for kl loss
|
|
||||||
:return: loss, ce, kl
|
|
||||||
"""
|
|
||||||
bpd_per_prior = []
|
|
||||||
|
|
||||||
# Likelihood of final representation.
|
|
||||||
log_pz = compute_log_pz(pz, z, args)
|
|
||||||
|
|
||||||
bpd_per_prior.append(convert_bpd(log_pz.detach(), args.input_size))
|
|
||||||
|
|
||||||
log_p = log_pz
|
|
||||||
|
|
||||||
# Add likelihoods of intermediate representations.
|
|
||||||
if ys:
|
|
||||||
log_pys = compute_log_ps(pys, ys, args)
|
|
||||||
|
|
||||||
for log_py in log_pys:
|
|
||||||
log_p += log_py
|
|
||||||
|
|
||||||
bpd_per_prior.append(convert_bpd(log_py.detach(), args.input_size))
|
|
||||||
|
|
||||||
log_p += ldj
|
|
||||||
|
|
||||||
loss = -log_p
|
|
||||||
bpd = convert_bpd(log_p.detach(), args.input_size)
|
|
||||||
|
|
||||||
return loss, bpd, bpd_per_prior
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_loss(pz, z, pys, ys, ldj, loss_aux, args):
|
|
||||||
return compute_loss_function(pz, z, pys, ys, ldj, loss_aux, args)
|
|
||||||
|
|
@ -1,174 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from optimization.loss import calculate_loss
|
|
||||||
from utils.visual_evaluation import plot_reconstructions
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, train_loader, model, opt, args):
|
|
||||||
model.train()
|
|
||||||
train_loss = np.zeros(len(train_loader))
|
|
||||||
train_bpd = np.zeros(len(train_loader))
|
|
||||||
|
|
||||||
num_data = 0
|
|
||||||
|
|
||||||
for batch_idx, (data, _) in enumerate(train_loader):
|
|
||||||
data = data.view(-1, *args.input_size)
|
|
||||||
|
|
||||||
data = data.to(args.DEVICE)
|
|
||||||
|
|
||||||
opt.zero_grad()
|
|
||||||
loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data)
|
|
||||||
|
|
||||||
loss = torch.mean(loss)
|
|
||||||
bpd = torch.mean(bpd)
|
|
||||||
bpd_per_prior = [torch.mean(i) for i in bpd_per_prior]
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
loss = loss.item()
|
|
||||||
train_loss[batch_idx] = loss
|
|
||||||
train_bpd[batch_idx] = bpd
|
|
||||||
|
|
||||||
ldj = torch.mean(ldj).item() / np.prod(args.input_size) / np.log(2)
|
|
||||||
|
|
||||||
opt.step()
|
|
||||||
|
|
||||||
num_data += len(data)
|
|
||||||
|
|
||||||
if batch_idx % args.log_interval == 0:
|
|
||||||
perc = 100. * batch_idx / len(train_loader)
|
|
||||||
|
|
||||||
tmp = 'Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] \tLoss: {:11.6f}\tbpd: {:8.6f}\tbits ldj: {:8.6f}'
|
|
||||||
print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, loss, bpd, ldj))
|
|
||||||
|
|
||||||
print('z min: {:8.3f}, max: {:8.3f}'.format(torch.min(z).item() * 256, torch.max(z).item() * 256))
|
|
||||||
|
|
||||||
print('z bpd: {:.3f}'.format(bpd_per_prior[0]))
|
|
||||||
for i in range(1, len(bpd_per_prior)):
|
|
||||||
print('y{} bpd: {:.3f}'.format(i-1, bpd_per_prior[i]))
|
|
||||||
|
|
||||||
print('pz mu', np.mean(pz[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
|
||||||
print('pz logs ', np.mean(pz[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
|
||||||
if len(pz) == 3:
|
|
||||||
print('pz pi ', np.mean(pz[2].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
|
||||||
|
|
||||||
for i, py in enumerate(pys):
|
|
||||||
print('py{} mu '.format(i), np.mean(py[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
|
||||||
print('py{} logs '.format(i), np.mean(py[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
|
||||||
|
|
||||||
from utils.visual_evaluation import plot_images
|
|
||||||
import os
|
|
||||||
if not os.path.exists(args.snap_dir + 'training/'):
|
|
||||||
os.makedirs(args.snap_dir + 'training/')
|
|
||||||
|
|
||||||
print('====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'.format(
|
|
||||||
epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader)))
|
|
||||||
|
|
||||||
return train_loss, train_bpd
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(train_loader, val_loader, model, model_sample, args, testing=False, file=None, epoch=0):
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
loss_type = 'bpd'
|
|
||||||
|
|
||||||
def analyse(data_loader, plot=False):
|
|
||||||
bpds = []
|
|
||||||
batch_idx = 0
|
|
||||||
with torch.no_grad():
|
|
||||||
for data, _ in data_loader:
|
|
||||||
batch_idx += 1
|
|
||||||
|
|
||||||
if args.cuda:
|
|
||||||
data = data.cuda()
|
|
||||||
|
|
||||||
data = data.view(-1, *args.input_size)
|
|
||||||
|
|
||||||
loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj = \
|
|
||||||
model(data)
|
|
||||||
loss = torch.mean(loss).item()
|
|
||||||
batch_bpd = torch.mean(batch_bpd).item()
|
|
||||||
|
|
||||||
bpds.append(batch_bpd)
|
|
||||||
|
|
||||||
bpd = np.mean(bpds)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
if not testing and plot:
|
|
||||||
x_sample = model_sample.sample(n=100)
|
|
||||||
|
|
||||||
try:
|
|
||||||
plot_reconstructions(
|
|
||||||
x_sample, bpd, loss_type, epoch, args)
|
|
||||||
except:
|
|
||||||
print('Not plotting')
|
|
||||||
|
|
||||||
return bpd
|
|
||||||
|
|
||||||
bpd_train = analyse(train_loader)
|
|
||||||
bpd_val = analyse(val_loader, plot=True)
|
|
||||||
|
|
||||||
with open(file, 'a') as ff:
|
|
||||||
msg = 'epoch {}\ttrain bpd {:.3f}\tval bpd {:.3f}\t'.format(
|
|
||||||
epoch,
|
|
||||||
bpd_train,
|
|
||||||
bpd_val)
|
|
||||||
print(msg, file=ff)
|
|
||||||
|
|
||||||
loss = bpd_val * np.prod(args.input_size) * np.log(2.)
|
|
||||||
bpd = bpd_val
|
|
||||||
|
|
||||||
file = None
|
|
||||||
|
|
||||||
# Compute log-likelihood
|
|
||||||
with torch.no_grad():
|
|
||||||
if testing:
|
|
||||||
test_data = val_loader.dataset.data_tensor
|
|
||||||
|
|
||||||
if args.cuda:
|
|
||||||
test_data = test_data.cuda()
|
|
||||||
|
|
||||||
print('Computing log-likelihood on test set')
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
log_likelihood = analyse(test_data)
|
|
||||||
|
|
||||||
else:
|
|
||||||
log_likelihood = None
|
|
||||||
nll_bpd = None
|
|
||||||
|
|
||||||
if file is None:
|
|
||||||
if testing:
|
|
||||||
print('====> Test set loss: {:.4f}'.format(loss))
|
|
||||||
print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood))
|
|
||||||
|
|
||||||
print('====> Test set bpd (elbo): {:.4f}'.format(bpd))
|
|
||||||
print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood/
|
|
||||||
(np.prod(args.input_size) * np.log(2.))))
|
|
||||||
|
|
||||||
else:
|
|
||||||
print('====> Validation set loss: {:.4f}'.format(loss))
|
|
||||||
print('====> Validation set bpd: {:.4f}'.format(bpd))
|
|
||||||
else:
|
|
||||||
with open(file, 'a') as ff:
|
|
||||||
if testing:
|
|
||||||
print('====> Test set loss: {:.4f}'.format(loss), file=ff)
|
|
||||||
print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood), file=ff)
|
|
||||||
|
|
||||||
print('====> Test set bpd: {:.4f}'.format(bpd), file=ff)
|
|
||||||
print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood /
|
|
||||||
(np.prod(args.input_size) * np.log(2.))),
|
|
||||||
file=ff)
|
|
||||||
|
|
||||||
else:
|
|
||||||
print('====> Validation set loss: {:.4f}'.format(loss), file=ff)
|
|
||||||
print('====> Validation set bpd: {:.4f}'.format(loss / (np.prod(args.input_size) * np.log(2.))),
|
|
||||||
file=ff)
|
|
||||||
|
|
||||||
if not testing:
|
|
||||||
return loss, bpd
|
|
||||||
else:
|
|
||||||
return log_likelihood, nll_bpd
|
|
||||||
|
|
@ -1,209 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import torch
|
|
||||||
import torch.utils.data
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import math
|
|
||||||
|
|
||||||
MIN_EPSILON = 1e-5
|
|
||||||
MAX_EPSILON = 1.-1e-5
|
|
||||||
|
|
||||||
|
|
||||||
PI = math.pi
|
|
||||||
|
|
||||||
|
|
||||||
def log_min_exp(a, b, epsilon=1e-8):
|
|
||||||
"""
|
|
||||||
Computes the log of exp(a) - exp(b) in a (more) numerically stable fashion.
|
|
||||||
Using:
|
|
||||||
log(exp(a) - exp(b))
|
|
||||||
c + log(exp(a-c) - exp(b-c))
|
|
||||||
a + log(1 - exp(b-a))
|
|
||||||
And note that we assume b < a always.
|
|
||||||
"""
|
|
||||||
y = a + torch.log(1 - torch.exp(b - a) + epsilon)
|
|
||||||
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
def log_normal(x, mean, logvar):
|
|
||||||
logp = -0.5 * logvar
|
|
||||||
logp += -0.5 * np.log(2 * PI)
|
|
||||||
logp += -0.5 * (x - mean) * (x - mean) / torch.exp(logvar)
|
|
||||||
return logp
|
|
||||||
|
|
||||||
|
|
||||||
def log_mixture_normal(x, mean, logvar, pi):
|
|
||||||
x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
|
|
||||||
|
|
||||||
logp_mixtures = log_normal(x, mean, logvar)
|
|
||||||
|
|
||||||
logp = torch.log(torch.sum(pi * torch.exp(logp_mixtures), dim=-1) + 1e-8)
|
|
||||||
|
|
||||||
return logp
|
|
||||||
|
|
||||||
|
|
||||||
def sample_normal(mean, logvar):
|
|
||||||
y = torch.randn_like(mean)
|
|
||||||
|
|
||||||
x = torch.exp(0.5 * logvar) * y + mean
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def sample_mixture_normal(mean, logvar, pi):
|
|
||||||
b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
|
|
||||||
pi = pi.view(b * c * h * w, n_mixtures)
|
|
||||||
sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)
|
|
||||||
|
|
||||||
# Select mixture params
|
|
||||||
mean = mean.view(b * c * h * w, n_mixtures)
|
|
||||||
mean = mean[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w)
|
|
||||||
logvar = logvar.view(b * c * h * w, n_mixtures)
|
|
||||||
logvar = logvar[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w)
|
|
||||||
|
|
||||||
y = sample_normal(mean, logvar)
|
|
||||||
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
def log_logistic(x, mean, logscale):
|
|
||||||
"""
|
|
||||||
pdf = sigma([x - mean] / scale) * [1 - sigma(...)] * 1/scale
|
|
||||||
"""
|
|
||||||
scale = torch.exp(logscale)
|
|
||||||
|
|
||||||
u = (x - mean) / scale
|
|
||||||
|
|
||||||
logp = F.logsigmoid(u) + F.logsigmoid(-u) - logscale
|
|
||||||
|
|
||||||
return logp
|
|
||||||
|
|
||||||
|
|
||||||
def sample_logistic(mean, logscale):
|
|
||||||
y = torch.rand_like(mean)
|
|
||||||
|
|
||||||
x = torch.exp(logscale) * torch.log(y / (1 - y)) + mean
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def log_discretized_logistic(x, mean, logscale, inverse_bin_width):
|
|
||||||
scale = torch.exp(logscale)
|
|
||||||
|
|
||||||
logp = log_min_exp(
|
|
||||||
F.logsigmoid((x + 0.5 / inverse_bin_width - mean) / scale),
|
|
||||||
F.logsigmoid((x - 0.5 / inverse_bin_width - mean) / scale))
|
|
||||||
|
|
||||||
return logp
|
|
||||||
|
|
||||||
|
|
||||||
def discretized_logistic_cdf(x, mean, logscale, inverse_bin_width):
|
|
||||||
scale = torch.exp(logscale)
|
|
||||||
|
|
||||||
cdf = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)
|
|
||||||
|
|
||||||
return cdf
|
|
||||||
|
|
||||||
|
|
||||||
def sample_discretized_logistic(mean, logscale, inverse_bin_width):
|
|
||||||
x = sample_logistic(mean, logscale)
|
|
||||||
|
|
||||||
x = torch.round(x * inverse_bin_width) / inverse_bin_width
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def normal_cdf(value, loc, std):
|
|
||||||
return 0.5 * (1 + torch.erf((value - loc) * std.reciprocal() / math.sqrt(2)))
|
|
||||||
|
|
||||||
|
|
||||||
def log_discretized_normal(x, mean, logvar, inverse_bin_width):
|
|
||||||
std = torch.exp(0.5 * logvar)
|
|
||||||
log_p = torch.log(normal_cdf(x + 0.5 / inverse_bin_width, mean, std) - normal_cdf(x - 0.5 / inverse_bin_width, mean, std) + 1e-7)
|
|
||||||
|
|
||||||
return log_p
|
|
||||||
|
|
||||||
|
|
||||||
def log_mixture_discretized_normal(x, mean, logvar, pi, inverse_bin_width):
|
|
||||||
std = torch.exp(0.5 * logvar)
|
|
||||||
|
|
||||||
x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
|
|
||||||
|
|
||||||
p = normal_cdf(x + 0.5 / inverse_bin_width, mean, std) - normal_cdf(x - 0.5 / inverse_bin_width, mean, std)
|
|
||||||
|
|
||||||
p = torch.sum(p * pi, dim=-1)
|
|
||||||
|
|
||||||
logp = torch.log(p + 1e-8)
|
|
||||||
|
|
||||||
return logp
|
|
||||||
|
|
||||||
|
|
||||||
def sample_discretized_normal(mean, logvar, inverse_bin_width):
|
|
||||||
y = torch.randn_like(mean)
|
|
||||||
|
|
||||||
x = torch.exp(0.5 * logvar) * y + mean
|
|
||||||
|
|
||||||
x = torch.round(x * inverse_bin_width) / inverse_bin_width
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def log_mixture_discretized_logistic(x, mean, logscale, pi, inverse_bin_width):
|
|
||||||
scale = torch.exp(logscale)
|
|
||||||
|
|
||||||
x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
|
|
||||||
|
|
||||||
p = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale) \
|
|
||||||
- torch.sigmoid((x - 0.5 / inverse_bin_width - mean) / scale)
|
|
||||||
|
|
||||||
p = torch.sum(p * pi, dim=-1)
|
|
||||||
|
|
||||||
logp = torch.log(p + 1e-8)
|
|
||||||
|
|
||||||
return logp
|
|
||||||
|
|
||||||
|
|
||||||
def mixture_discretized_logistic_cdf(x, mean, logscale, pi, inverse_bin_width):
|
|
||||||
scale = torch.exp(logscale)
|
|
||||||
|
|
||||||
x = x[..., None]
|
|
||||||
|
|
||||||
cdfs = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)
|
|
||||||
|
|
||||||
cdf = torch.sum(cdfs * pi, dim=-1)
|
|
||||||
|
|
||||||
return cdf
|
|
||||||
|
|
||||||
|
|
||||||
def sample_mixture_discretized_logistic(mean, logs, pi, inverse_bin_width):
|
|
||||||
# Sample mixtures
|
|
||||||
b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
|
|
||||||
pi = pi.view(b * c * h * w, n_mixtures)
|
|
||||||
sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)
|
|
||||||
|
|
||||||
# Select mixture params
|
|
||||||
mean = mean.view(b * c * h * w, n_mixtures)
|
|
||||||
mean = mean[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w)
|
|
||||||
logs = logs.view(b * c * h * w, n_mixtures)
|
|
||||||
logs = logs[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w)
|
|
||||||
|
|
||||||
y = torch.rand_like(mean)
|
|
||||||
x = torch.exp(logs) * torch.log(y / (1 - y)) + mean
|
|
||||||
|
|
||||||
x = torch.round(x * inverse_bin_width) / inverse_bin_width
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def log_multinomial(logits, targets):
|
|
||||||
return -F.cross_entropy(logits, targets, reduction='none')
|
|
||||||
|
|
||||||
|
|
||||||
def sample_multinomial(logits):
|
|
||||||
b, n_categories, c, h, w = logits.size()
|
|
||||||
logits = logits.permute(0, 2, 3, 4, 1)
|
|
||||||
p = F.softmax(logits, dim=-1)
|
|
||||||
p = p.view(b * c * h * w, n_categories)
|
|
||||||
x = torch.multinomial(p, num_samples=1).view(b, c, h, w)
|
|
||||||
return x
|
|
||||||
|
|
@ -1,264 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import numbers
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.utils.data as data_utils
|
|
||||||
import pickle
|
|
||||||
from scipy.io import loadmat
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import os
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
import torchvision
|
|
||||||
from torchvision import transforms
|
|
||||||
from torchvision.transforms import functional as vf
|
|
||||||
from torch.utils.data import ConcatDataset
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import os
|
|
||||||
import os.path
|
|
||||||
from os.path import join
|
|
||||||
import sys
|
|
||||||
import tarfile
|
|
||||||
|
|
||||||
|
|
||||||
class ToTensorNoNorm():
|
|
||||||
def __call__(self, X_i):
|
|
||||||
return torch.from_numpy(np.array(X_i, copy=False)).permute(2, 0, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class PadToMultiple(object):
|
|
||||||
def __init__(self, multiple, fill=0, padding_mode='constant'):
|
|
||||||
assert isinstance(multiple, numbers.Number)
|
|
||||||
assert isinstance(fill, (numbers.Number, str, tuple))
|
|
||||||
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
|
|
||||||
|
|
||||||
self.multiple = multiple
|
|
||||||
self.fill = fill
|
|
||||||
self.padding_mode = padding_mode
|
|
||||||
|
|
||||||
def __call__(self, img):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
img (PIL Image): Image to be padded.
|
|
||||||
Returns:
|
|
||||||
PIL Image: Padded image.
|
|
||||||
"""
|
|
||||||
w, h = img.size
|
|
||||||
m = self.multiple
|
|
||||||
nw = (w // m + int((w % m) != 0)) * m
|
|
||||||
nh = (h // m + int((h % m) != 0)) * m
|
|
||||||
padw = nw - w
|
|
||||||
padh = nh - h
|
|
||||||
|
|
||||||
out = vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\
|
|
||||||
format(self.mulitple, self.fill, self.padding_mode)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomTensorDataset(Dataset):
|
|
||||||
"""Dataset wrapping tensors.
|
|
||||||
|
|
||||||
Each sample will be retrieved by indexing tensors along the first dimension.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
*tensors (Tensor): tensors that have the same size of the first dimension.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *tensors, transform=None):
|
|
||||||
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
|
|
||||||
self.tensors = tensors
|
|
||||||
self.transform = transform
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
X, y = self.tensors
|
|
||||||
X_i, y_i, = X[index], y[index]
|
|
||||||
|
|
||||||
if self.transform:
|
|
||||||
X_i = self.transform(X_i)
|
|
||||||
X_i = torch.from_numpy(np.array(X_i, copy=False))
|
|
||||||
X_i = X_i.permute(2, 0, 1)
|
|
||||||
|
|
||||||
return X_i, y_i
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.tensors[0].size(0)
|
|
||||||
|
|
||||||
|
|
||||||
def load_cifar10(args, **kwargs):
|
|
||||||
# set args
|
|
||||||
args.input_size = [3, 32, 32]
|
|
||||||
args.input_type = 'continuous'
|
|
||||||
args.dynamic_binarization = False
|
|
||||||
|
|
||||||
from keras.datasets import cifar10
|
|
||||||
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
|
|
||||||
|
|
||||||
x_train = x_train.transpose(0, 3, 1, 2)
|
|
||||||
x_test = x_test.transpose(0, 3, 1, 2)
|
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
if args.data_augmentation_level == 2:
|
|
||||||
data_transform = transforms.Compose([
|
|
||||||
transforms.ToPILImage(),
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'),
|
|
||||||
transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
|
|
||||||
transforms.CenterCrop(32)
|
|
||||||
])
|
|
||||||
elif args.data_augmentation_level == 1:
|
|
||||||
data_transform = transforms.Compose([
|
|
||||||
transforms.ToPILImage(),
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
data_transform = transforms.Compose([
|
|
||||||
transforms.ToPILImage(),
|
|
||||||
])
|
|
||||||
|
|
||||||
x_val = x_train[-10000:]
|
|
||||||
y_val = y_train[-10000:]
|
|
||||||
|
|
||||||
x_train = x_train[:-10000]
|
|
||||||
y_train = y_train[:-10000]
|
|
||||||
|
|
||||||
train = CustomTensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train), transform=data_transform)
|
|
||||||
train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
|
|
||||||
|
|
||||||
validation = data_utils.TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val))
|
|
||||||
val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs)
|
|
||||||
|
|
||||||
test = data_utils.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test))
|
|
||||||
test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs)
|
|
||||||
|
|
||||||
return train_loader, val_loader, test_loader, args
|
|
||||||
|
|
||||||
|
|
||||||
def extract_tar(tarpath):
|
|
||||||
assert tarpath.endswith('.tar')
|
|
||||||
|
|
||||||
startdir = tarpath[:-4] + '/'
|
|
||||||
|
|
||||||
if os.path.exists(startdir):
|
|
||||||
return startdir
|
|
||||||
|
|
||||||
print('Extracting', tarpath)
|
|
||||||
|
|
||||||
with tarfile.open(name=tarpath) as tar:
|
|
||||||
t = 0
|
|
||||||
done = False
|
|
||||||
while not done:
|
|
||||||
path = join(startdir, 'images{}'.format(t))
|
|
||||||
os.makedirs(path, exist_ok=True)
|
|
||||||
|
|
||||||
print(path)
|
|
||||||
|
|
||||||
for i in range(50000):
|
|
||||||
member = tar.next()
|
|
||||||
|
|
||||||
if member is None:
|
|
||||||
done = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# Skip directories
|
|
||||||
while member.isdir():
|
|
||||||
member = tar.next()
|
|
||||||
if member is None:
|
|
||||||
done = True
|
|
||||||
break
|
|
||||||
|
|
||||||
member.name = member.name.split('/')[-1]
|
|
||||||
|
|
||||||
tar.extract(member, path=path)
|
|
||||||
|
|
||||||
t += 1
|
|
||||||
|
|
||||||
return startdir
|
|
||||||
|
|
||||||
|
|
||||||
def load_imagenet(resolution, args, **kwargs):
|
|
||||||
assert resolution == 32 or resolution == 64
|
|
||||||
|
|
||||||
args.input_size = [3, resolution, resolution]
|
|
||||||
|
|
||||||
trainpath = '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution)
|
|
||||||
valpath = '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution)
|
|
||||||
|
|
||||||
trainpath = extract_tar(trainpath)
|
|
||||||
valpath = extract_tar(valpath)
|
|
||||||
|
|
||||||
data_transform = transforms.Compose([
|
|
||||||
ToTensorNoNorm()
|
|
||||||
])
|
|
||||||
|
|
||||||
print('Starting loading ImageNet')
|
|
||||||
|
|
||||||
imagenet_data = torchvision.datasets.ImageFolder(
|
|
||||||
trainpath,
|
|
||||||
transform=data_transform)
|
|
||||||
|
|
||||||
print('Number of data images', len(imagenet_data))
|
|
||||||
|
|
||||||
val_idcs = np.random.choice(len(imagenet_data), size=20000, replace=False)
|
|
||||||
train_idcs = np.setdiff1d(np.arange(len(imagenet_data)), val_idcs)
|
|
||||||
|
|
||||||
train_dataset = torch.utils.data.dataset.Subset(
|
|
||||||
imagenet_data, train_idcs)
|
|
||||||
val_dataset = torch.utils.data.dataset.Subset(
|
|
||||||
imagenet_data, val_idcs)
|
|
||||||
|
|
||||||
train_loader = torch.utils.data.DataLoader(
|
|
||||||
train_dataset,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
val_loader = torch.utils.data.DataLoader(
|
|
||||||
val_dataset,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
test_dataset = torchvision.datasets.ImageFolder(
|
|
||||||
valpath,
|
|
||||||
transform=data_transform)
|
|
||||||
|
|
||||||
print('Number of val images:', len(test_dataset))
|
|
||||||
|
|
||||||
test_loader = torch.utils.data.DataLoader(
|
|
||||||
test_dataset,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
return train_loader, val_loader, test_loader, args
|
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(args, **kwargs):
|
|
||||||
|
|
||||||
if args.dataset == 'cifar10':
|
|
||||||
train_loader, val_loader, test_loader, args = load_cifar10(args, **kwargs)
|
|
||||||
elif args.dataset == 'imagenet32':
|
|
||||||
train_loader, val_loader, test_loader, args = load_imagenet(32, args, **kwargs)
|
|
||||||
elif args.dataset == 'imagenet64':
|
|
||||||
train_loader, val_loader, test_loader, args = load_imagenet(64, args, **kwargs)
|
|
||||||
else:
|
|
||||||
raise Exception('Wrong name of the dataset!')
|
|
||||||
|
|
||||||
return train_loader, val_loader, test_loader, args
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
class Args():
|
|
||||||
def __init__(self):
|
|
||||||
self.batch_size = 128
|
|
||||||
train_loader, val_loader, test_loader, args = load_imagenet32(Args())
|
|
||||||
|
|
@ -1,56 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import numpy as np
|
|
||||||
from scipy.misc import logsumexp
|
|
||||||
from optimization.loss import calculate_loss_array
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_likelihood(X, model, args, S=5000, MB=500):
|
|
||||||
|
|
||||||
# set auxiliary variables for number of training and test sets
|
|
||||||
N_test = X.size(0)
|
|
||||||
|
|
||||||
X = X.view(-1, *args.input_size)
|
|
||||||
|
|
||||||
likelihood_test = []
|
|
||||||
|
|
||||||
if S <= MB:
|
|
||||||
R = 1
|
|
||||||
else:
|
|
||||||
R = S // MB
|
|
||||||
S = MB
|
|
||||||
|
|
||||||
for j in range(N_test):
|
|
||||||
if j % 100 == 0:
|
|
||||||
print('Progress: {:.2f}%'.format(j / (1. * N_test) * 100))
|
|
||||||
|
|
||||||
x_single = X[j].unsqueeze(0)
|
|
||||||
|
|
||||||
a = []
|
|
||||||
for r in range(0, R):
|
|
||||||
# Repeat it for all training points
|
|
||||||
x = x_single.expand(S, *x_single.size()[1:]).contiguous()
|
|
||||||
|
|
||||||
x_mean, z_mu, z_var, ldj, z0, zk = model(x)
|
|
||||||
|
|
||||||
a_tmp = calculate_loss_array(x_mean, x, z_mu, z_var, z0, zk, ldj, args)
|
|
||||||
|
|
||||||
a.append(-a_tmp.cpu().data.numpy())
|
|
||||||
|
|
||||||
# calculate max
|
|
||||||
a = np.asarray(a)
|
|
||||||
a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
|
|
||||||
likelihood_x = logsumexp(a)
|
|
||||||
likelihood_test.append(likelihood_x - np.log(len(a)))
|
|
||||||
|
|
||||||
likelihood_test = np.array(likelihood_test)
|
|
||||||
|
|
||||||
nll = -np.mean(likelihood_test)
|
|
||||||
|
|
||||||
if args.input_type == 'multinomial':
|
|
||||||
bpd = nll/(np.prod(args.input_size) * np.log(2.))
|
|
||||||
elif args.input_type == 'binary':
|
|
||||||
bpd = 0.
|
|
||||||
else:
|
|
||||||
raise ValueError('invalid input type!')
|
|
||||||
|
|
||||||
return nll, bpd
|
|
||||||
|
|
@ -1,104 +0,0 @@
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib
|
|
||||||
# noninteractive background
|
|
||||||
matplotlib.use('Agg')
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
|
|
||||||
def plot_training_curve(train_loss, validation_loss, fname='training_curve.pdf', labels=None):
|
|
||||||
"""
|
|
||||||
Plots train_loss and validation loss as a function of optimization iteration
|
|
||||||
:param train_loss: np.array of train_loss (1D or 2D)
|
|
||||||
:param validation_loss: np.array of validation loss (1D or 2D)
|
|
||||||
:param fname: output file name
|
|
||||||
:param labels: if train_loss and validation loss are 2D, then labels indicate which variable is varied
|
|
||||||
accross training curves.
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
matplotlib.rcParams.update({'font.size': 14})
|
|
||||||
matplotlib.rcParams['mathtext.fontset'] = 'stix'
|
|
||||||
matplotlib.rcParams['font.family'] = 'STIXGeneral'
|
|
||||||
|
|
||||||
if len(train_loss.shape) == 1:
|
|
||||||
# Single training curve
|
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=1)
|
|
||||||
figsize = (6, 4)
|
|
||||||
|
|
||||||
if train_loss.shape[0] == validation_loss.shape[0]:
|
|
||||||
# validation score evaluated every iteration
|
|
||||||
x = np.arange(train_loss.shape[0])
|
|
||||||
ax.plot(x, train_loss, '-', lw=2., color='black', label='train')
|
|
||||||
ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')
|
|
||||||
|
|
||||||
elif train_loss.shape[0] % validation_loss.shape[0] == 0:
|
|
||||||
# validation score evaluated every epoch
|
|
||||||
x = np.arange(train_loss.shape[0])
|
|
||||||
ax.plot(x, train_loss, '-', lw=2., color='black', label='train')
|
|
||||||
|
|
||||||
x = np.arange(validation_loss.shape[0])
|
|
||||||
x = (x + 1) * train_loss.shape[0] / validation_loss.shape[0]
|
|
||||||
ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')
|
|
||||||
else:
|
|
||||||
raise ValueError('Length of train_loss and validation_loss must be equal or divisible')
|
|
||||||
|
|
||||||
miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
|
|
||||||
maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
|
|
||||||
ax.set_ylim([miny, maxy])
|
|
||||||
|
|
||||||
elif len(train_loss.shape) == 2:
|
|
||||||
# Multiple training curves
|
|
||||||
|
|
||||||
cmap = plt.cm.brg
|
|
||||||
|
|
||||||
cNorm = matplotlib.colors.Normalize(vmin=0, vmax=train_loss.shape[0])
|
|
||||||
scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=1)
|
|
||||||
figsize = (6, 4)
|
|
||||||
|
|
||||||
if labels is None:
|
|
||||||
labels = ['%d' % i for i in range(train_loss.shape[0])]
|
|
||||||
|
|
||||||
if train_loss.shape[1] == validation_loss.shape[1]:
|
|
||||||
for i in range(train_loss.shape[0]):
|
|
||||||
color_val = scalarMap.to_rgba(i)
|
|
||||||
|
|
||||||
# validation score evaluated every iteration
|
|
||||||
x = np.arange(train_loss.shape[0])
|
|
||||||
ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i])
|
|
||||||
ax.plot(x, validation_loss[i], '--', lw=2., color=color_val)
|
|
||||||
|
|
||||||
elif train_loss.shape[1] % validation_loss.shape[1] == 0:
|
|
||||||
for i in range(train_loss.shape[0]):
|
|
||||||
color_val = scalarMap.to_rgba(i)
|
|
||||||
|
|
||||||
# validation score evaluated every epoch
|
|
||||||
x = np.arange(train_loss.shape[1])
|
|
||||||
ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i])
|
|
||||||
|
|
||||||
x = np.arange(validation_loss.shape[1])
|
|
||||||
x = (x+1) * train_loss.shape[1] / validation_loss.shape[1]
|
|
||||||
ax.plot(x, validation_loss[i], '-', lw=2., color=color_val)
|
|
||||||
|
|
||||||
miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
|
|
||||||
maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
|
|
||||||
ax.set_ylim([miny, maxy])
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError('train_loss and validation_loss must be 1D or 2D arrays')
|
|
||||||
|
|
||||||
ax.set_xlabel('iteration')
|
|
||||||
ax.set_ylabel('loss')
|
|
||||||
plt.title('Training and validation loss')
|
|
||||||
|
|
||||||
fig.set_size_inches(figsize)
|
|
||||||
fig.subplots_adjust(hspace=0.1)
|
|
||||||
plt.savefig(fname, bbox_inches='tight')
|
|
||||||
|
|
||||||
plt.close()
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
import imageio
|
|
||||||
|
|
||||||
|
|
||||||
def plot_reconstructions(recon_mean, loss, loss_type, epoch, args):
|
|
||||||
if epoch == 1:
|
|
||||||
if not os.path.exists(args.snap_dir + 'reconstruction/'):
|
|
||||||
os.makedirs(args.snap_dir + 'reconstruction/')
|
|
||||||
if loss_type == 'bpd':
|
|
||||||
fname = str(epoch) + '_bpd_%5.3f' % loss
|
|
||||||
elif loss_type == 'elbo':
|
|
||||||
fname = str(epoch) + '_elbo_%6.4f' % loss
|
|
||||||
plot_images(args, recon_mean.data.cpu().numpy()[:100], args.snap_dir + 'reconstruction/', fname)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_images(args, x_sample, dir, file_name, size_x=10, size_y=10):
|
|
||||||
batch, channels, height, width = x_sample.shape
|
|
||||||
|
|
||||||
print(x_sample.shape)
|
|
||||||
|
|
||||||
mosaic = np.zeros((height * size_y, width * size_x, channels))
|
|
||||||
|
|
||||||
for j in range(size_y):
|
|
||||||
for i in range(size_x):
|
|
||||||
idx = j * size_x + i
|
|
||||||
|
|
||||||
image = x_sample[idx]
|
|
||||||
|
|
||||||
mosaic[j*height:(j+1)*height, i*height:(i+1)*height] = \
|
|
||||||
image.transpose(1, 2, 0)
|
|
||||||
|
|
||||||
# Remove channel for BW images
|
|
||||||
mosaic = mosaic.squeeze()
|
|
||||||
|
|
||||||
imageio.imwrite(dir + file_name + '.png', mosaic)
|
|
||||||
|
|
@ -89,7 +89,7 @@ def main():
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
if args.model_path is not None:
|
if args.model_path is not None:
|
||||||
print("Loading the model...")
|
print("Loading the models...")
|
||||||
model = torch.load(args.model_path)
|
model = torch.load(args.model_path)
|
||||||
|
|
||||||
trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer()
|
trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer()
|
||||||
1
models/cnn/__init__.py
Normal file
1
models/cnn/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from .cnn import CNNPredictor
|
||||||
|
|
@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .trainer import Trainer
|
from .trainer import Trainer
|
||||||
from .train import train
|
from .train import train
|
||||||
from utils import print_losses
|
from ..utils import print_losses
|
||||||
|
|
||||||
class FullTrainer(Trainer):
|
class FullTrainer(Trainer):
|
||||||
def execute(
|
def execute(
|
||||||
|
|
@ -7,7 +7,7 @@ from torch import nn as nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .trainer import Trainer
|
from .trainer import Trainer
|
||||||
from model.cnn import CNNPredictor
|
from ..models.cnn import CNNPredictor
|
||||||
from .train import train
|
from .train import train
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -59,4 +59,4 @@ class OptunaTrainer(Trainer):
|
||||||
best_model = CNNPredictor(
|
best_model = CNNPredictor(
|
||||||
**best_params
|
**best_params
|
||||||
)
|
)
|
||||||
torch.save(best_model, "models/final_model.pt")
|
torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")
|
||||||
Reference in a new issue