Merge branch 'main' into portability
This commit is contained in:
commit
013afe832f
44 changed files with 5 additions and 2834 deletions
|
|
@ -1,19 +0,0 @@
|
|||
Copyright (c) 2019 Emiel Hoogeboom, Jorn Peters, Rianne van den Berg, Max Welling
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
# Integer Discrete Flows and Lossless Compression
|
||||
|
||||
This repository contains the code for the experiments presented in [1].
|
||||
|
||||
## Usage
|
||||
|
||||
### CIFAR10 setup:
|
||||
```
|
||||
python main_experiment.py --n_flows 8 --n_levels 3 --n_channels 512 --coupling_type 'densenet' --densenet_depth 12 —n_mixtures 5 —splitprior_type ‘densenet’
|
||||
```
|
||||
|
||||
|
||||
### ImageNet32 setup:
|
||||
```
|
||||
python main_experiment.py --evaluate_interval_epochs 5 --n_flows 8 --n_levels 3 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet32' --epochs 100 --lr_decay 0.99
|
||||
```
|
||||
|
||||
|
||||
### ImageNet64 setup:
|
||||
```
|
||||
python main_experiment.py --evaluate_interval_epochs 1 --n_flows 8 --n_levels 4 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet64' --epochs 20 --lr_decay 0.99 --batch_size 64
|
||||
```
|
||||
|
||||
# Acknowledgements
|
||||
The Robert Bosch GmbH is acknowledged for financial support.
|
||||
|
||||
# References
|
||||
[1] Hoogeboom, Emiel, Jorn WT Peters, Rianne van den Berg, and Max Welling. "Integer Discrete Flows and Lossless Compression." Conference on Neural Information Processing Systems (2019).
|
||||
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
import numpy as np
|
||||
from . import rans
|
||||
from utils.distributions import discretized_logistic_cdf, \
|
||||
mixture_discretized_logistic_cdf
|
||||
import torch
|
||||
|
||||
precision = 24
|
||||
n_bins = 4096
|
||||
|
||||
|
||||
def cdf_fn(z, pz, variable_type, distribution_type, inverse_bin_width):
|
||||
if variable_type == 'discrete':
|
||||
if distribution_type == 'logistic':
|
||||
if len(pz) == 2:
|
||||
return discretized_logistic_cdf(
|
||||
z, *pz, inverse_bin_width=inverse_bin_width)
|
||||
elif len(pz) == 3:
|
||||
return mixture_discretized_logistic_cdf(
|
||||
z, *pz, inverse_bin_width=inverse_bin_width)
|
||||
elif distribution_type == 'normal':
|
||||
pass
|
||||
|
||||
elif variable_type == 'continuous':
|
||||
if distribution_type == 'logistic':
|
||||
pass
|
||||
elif distribution_type == 'normal':
|
||||
pass
|
||||
elif distribution_type == 'steplogistic':
|
||||
pass
|
||||
raise ValueError
|
||||
|
||||
|
||||
def CDF_fn(pz, bin_width, variable_type, distribution_type):
|
||||
mean = pz[0] if len(pz) == 2 else pz[0][..., (pz[0].size(-1) - 1) // 2]
|
||||
MEAN = torch.round(mean / bin_width).long()
|
||||
|
||||
bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
|
||||
bin_locations = bin_locations.float() * bin_width
|
||||
bin_locations = bin_locations.to(device=pz[0].DEVICE)
|
||||
|
||||
pz = [param[:, :, :, :, None] for param in pz]
|
||||
cdf = cdf_fn(
|
||||
bin_locations - bin_width,
|
||||
pz,
|
||||
variable_type,
|
||||
distribution_type,
|
||||
1./bin_width).cpu().numpy()
|
||||
|
||||
# Compute CDFs, reweigh to give all bins at least
|
||||
# 1 / (2^precision) probability.
|
||||
# CDF is equal to floor[cdf * (2^precision - n_bins)] + range(n_bins)
|
||||
CDFs = (cdf * ((1 << precision) - n_bins)).astype('int') \
|
||||
+ np.arange(n_bins)
|
||||
|
||||
return CDFs, MEAN
|
||||
|
||||
|
||||
def encode_sample(
|
||||
z, pz, variable_type, distribution_type, bin_width=1./256, state=None):
|
||||
if state is None:
|
||||
state = rans.x_init
|
||||
else:
|
||||
state = rans.unflatten(state)
|
||||
|
||||
CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)
|
||||
|
||||
# z is transformed to Z to match the indices for the CDFs array
|
||||
Z = torch.round(z / bin_width).long() + n_bins // 2 - MEAN
|
||||
Z = Z.cpu().numpy()
|
||||
|
||||
if not ((np.sum(Z < 0) == 0 and np.sum(Z >= n_bins-1) == 0)):
|
||||
print('Z out of allowed range of values, canceling compression')
|
||||
return None
|
||||
|
||||
Z, CDFs = Z.reshape(-1), CDFs.reshape(-1, n_bins).copy()
|
||||
for symbol, cdf in zip(Z[::-1], CDFs[::-1]):
|
||||
statfun = statfun_encode(cdf)
|
||||
state = rans.append_symbol(statfun, precision)(state, symbol)
|
||||
|
||||
state = rans.flatten(state)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def decode_sample(
|
||||
state, pz, variable_type, distribution_type, bin_width=1./256):
|
||||
state = rans.unflatten(state)
|
||||
|
||||
device = pz[0].DEVICE
|
||||
size = pz[0].size()[0:4]
|
||||
|
||||
CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)
|
||||
|
||||
CDFs = CDFs.reshape(-1, n_bins)
|
||||
result = np.zeros(len(CDFs), dtype=int)
|
||||
for i, cdf in enumerate(CDFs):
|
||||
statfun = statfun_decode(cdf)
|
||||
state, symbol = rans.pop_symbol(statfun, precision)(state)
|
||||
result[i] = symbol
|
||||
|
||||
Z_flat = torch.from_numpy(result).to(device)
|
||||
Z = Z_flat.view(size) - n_bins // 2 + MEAN
|
||||
|
||||
z = Z.float() * bin_width
|
||||
|
||||
state = rans.flatten(state)
|
||||
|
||||
return state, z
|
||||
|
||||
|
||||
def statfun_encode(CDF):
|
||||
def _statfun_encode(symbol):
|
||||
return CDF[symbol], CDF[symbol + 1] - CDF[symbol]
|
||||
return _statfun_encode
|
||||
|
||||
|
||||
def statfun_decode(CDF):
|
||||
def _statfun_decode(cf):
|
||||
# Search such that CDF[s] <= cf < CDF[s]
|
||||
s = np.searchsorted(CDF, cf, side='right')
|
||||
s = s - 1
|
||||
start, freq = statfun_encode(CDF)(s)
|
||||
return s, (start, freq)
|
||||
return _statfun_decode
|
||||
|
||||
|
||||
def encode(x, symbol):
|
||||
return rans.append_symbol(statfun_encode, precision)(x, symbol)
|
||||
|
||||
|
||||
def decode(x):
|
||||
return rans.pop_symbol(statfun_decode, precision)(x)
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
"""
|
||||
Closely based on https://github.com/rygorous/ryg_rans/blob/master/rans64.h
|
||||
|
||||
ROUGH GUIDE:
|
||||
We use the pythonic names 'append' and 'pop' for encoding and decoding
|
||||
respectively. The compressed state 'x' is an immutable stack, implemented using
|
||||
a cons list.
|
||||
|
||||
x: the current stack-like state of the encoder/decoder.
|
||||
|
||||
precision: the natural numbers are divided into ranges of size 2^precision.
|
||||
|
||||
start & freq: start indicates the beginning of the range in [0, 2^precision-1]
|
||||
that the current symbol is represented by. freq is the length of the range.
|
||||
freq is chosen such that p(symbol) ~= freq/2^precision.
|
||||
"""
|
||||
import numpy as np
|
||||
from functools import reduce
|
||||
|
||||
|
||||
rans_l = 1 << 31 # the lower bound of the normalisation interval
|
||||
tail_bits = (1 << 32) - 1
|
||||
|
||||
x_init = (rans_l, ())
|
||||
|
||||
def append(x, start, freq, precision):
|
||||
"""Encodes a symbol with range [start, start + freq). All frequencies are
|
||||
assumed to sum to "1 << precision", and the resulting bits get written to
|
||||
x."""
|
||||
if x[0] >= ((rans_l >> precision) << 32) * freq:
|
||||
x = (x[0] >> 32, (x[0] & tail_bits, x[1]))
|
||||
return ((x[0] // freq) << precision) + (x[0] % freq) + start, x[1]
|
||||
|
||||
def pop(x_, precision):
|
||||
"""Advances in the bit stream by "popping" a single symbol with range start
|
||||
"start" and frequency "freq"."""
|
||||
cf = x_[0] & ((1 << precision) - 1)
|
||||
def pop(start, freq):
|
||||
x = freq * (x_[0] >> precision) + cf - start, x_[1]
|
||||
return ((x[0] << 32) | x[1][0], x[1][1]) if x[0] < rans_l else x
|
||||
return cf, pop
|
||||
|
||||
def append_symbol(statfun, precision):
|
||||
def append_(x, symbol):
|
||||
start, freq = statfun(symbol)
|
||||
return append(x, start, freq, precision)
|
||||
return append_
|
||||
|
||||
def pop_symbol(statfun, precision):
|
||||
def pop_(x):
|
||||
cf, pop_fun = pop(x, precision)
|
||||
symbol, (start, freq) = statfun(cf)
|
||||
return pop_fun(start, freq), symbol
|
||||
return pop_
|
||||
|
||||
def flatten(x):
|
||||
"""Flatten a rans state x into a 1d numpy array."""
|
||||
out, x = [x[0] >> 32, x[0]], x[1]
|
||||
while x:
|
||||
x_head, x = x
|
||||
out.append(x_head)
|
||||
return np.asarray(out, dtype=np.uint32)
|
||||
|
||||
def unflatten(arr):
|
||||
"""Unflatten a 1d numpy array into a rans state."""
|
||||
return (int(arr[0]) << 32 | int(arr[1]),
|
||||
reduce(lambda tl, hd: (int(hd), tl), reversed(arr[2:]), ()))
|
||||
|
|
@ -1,188 +0,0 @@
|
|||
# !/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
|
||||
from utils.load_data import load_dataset
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')
|
||||
|
||||
parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'],
|
||||
metavar='DATASET',
|
||||
help='Dataset choice.')
|
||||
|
||||
parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
|
||||
parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')
|
||||
|
||||
parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
|
||||
help='how many batches to wait before logging training status')
|
||||
|
||||
parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
|
||||
help='Evaluate per how many epochs')
|
||||
|
||||
|
||||
# optimization settings
|
||||
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
|
||||
help='number of epochs to train (default: 2000)')
|
||||
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
|
||||
help='number of early stopping epochs')
|
||||
|
||||
parser.add_argument('-bs', '--batch_size', type=int, default=10, metavar='BATCH_SIZE',
|
||||
help='input batch size for training (default: 100)')
|
||||
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
|
||||
help='learning rate')
|
||||
parser.add_argument('--warmup', type=int, default=10,
|
||||
help='number of warmup epochs')
|
||||
|
||||
parser.add_argument('--data_augmentation_level', type=int, default=2,
|
||||
help='data augmentation level')
|
||||
|
||||
parser.add_argument('--no_decode', action='store_true', default=False,
|
||||
help='disables decoding')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
|
||||
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
|
||||
|
||||
|
||||
def encode_images(img, model, decode):
|
||||
batchsize, img_c, img_h, img_w = img.size()
|
||||
c, h, w = model.args.input_size
|
||||
|
||||
assert img_h == img_w and h == w
|
||||
|
||||
if img_h != h:
|
||||
assert img_h % h == 0
|
||||
steps = img_h // h
|
||||
|
||||
states = [[] for i in range(batchsize)]
|
||||
state_sizes = [0 for i in range(batchsize)]
|
||||
bpd = [0 for i in range(batchsize)]
|
||||
error = 0
|
||||
|
||||
for j in range(steps):
|
||||
for i in range(steps):
|
||||
r = encode_patches(
|
||||
img[:, :, j*h:(j+1)*h, i*w:(i+1)*w], model, decode)
|
||||
for b in range(batchsize):
|
||||
|
||||
if r[0][b] is None:
|
||||
states[b].append(None)
|
||||
else:
|
||||
states[b].extend(r[0][b])
|
||||
state_sizes[b] += r[1][b]
|
||||
bpd[b] += r[2][b] / steps**2
|
||||
error += r[3]
|
||||
return states, state_sizes, bpd, error
|
||||
else:
|
||||
return encode_patches(img, model, decode)
|
||||
|
||||
|
||||
def encode_patches(imgs, model, decode):
|
||||
batchsize, img_c, img_h, img_w = imgs.size()
|
||||
c, h, w = model.args.input_size
|
||||
assert img_h == h and img_w == w
|
||||
|
||||
states = model.encode(imgs)
|
||||
|
||||
bpd = model.forward(imgs)[1].cpu().numpy()
|
||||
|
||||
state_sizes = []
|
||||
error = 0
|
||||
|
||||
for b in range(batchsize):
|
||||
if states[b] is None:
|
||||
# Using escape bit ;)
|
||||
state_sizes += [8 * img_c * img_h * img_w + 1]
|
||||
|
||||
# Error remains unchanged.
|
||||
print('Escaping, not encoding.')
|
||||
|
||||
else:
|
||||
if decode:
|
||||
x_recon = model.decode([states[b]])
|
||||
|
||||
error += torch.sum(
|
||||
torch.abs(x_recon.int() - imgs[b].int())).item()
|
||||
|
||||
# Append state plus an escape bit
|
||||
state_sizes += [32 * len(states[b]) + 1]
|
||||
|
||||
return states, state_sizes, bpd, error
|
||||
|
||||
|
||||
def run(args, kwargs):
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
args.snap_dir = snap_dir = \
|
||||
'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'
|
||||
|
||||
# ==================================================================================================================
|
||||
# SNAPSHOTS
|
||||
# ==================================================================================================================
|
||||
|
||||
# ==================================================================================================================
|
||||
# LOAD DATA
|
||||
# ==================================================================================================================
|
||||
train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)
|
||||
|
||||
final_model = torch.load(snap_dir + 'a.model')
|
||||
if hasattr(final_model, 'module'):
|
||||
final_model = final_model.module
|
||||
final_model = final_model.cuda()
|
||||
|
||||
sizes = []
|
||||
errors = []
|
||||
bpds = []
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
t = 0
|
||||
with torch.no_grad():
|
||||
for data, _ in test_loader:
|
||||
if args.cuda:
|
||||
data = data.cuda()
|
||||
|
||||
state, state_sizes, bpd, error = \
|
||||
encode_images(data, final_model, decode=not args.no_decode)
|
||||
|
||||
errors += [error]
|
||||
bpds.extend(bpd)
|
||||
sizes.extend(state_sizes)
|
||||
|
||||
t += len(data)
|
||||
|
||||
print(
|
||||
'Examples: {}/{} bpd compression: {:.3f} error: {},'
|
||||
' analytical bpd {:.3f}'.format(
|
||||
t, len(test_loader.dataset),
|
||||
np.mean(sizes) / np.prod(data.size()[1:]),
|
||||
np.sum(errors),
|
||||
np.mean(bpds)
|
||||
))
|
||||
|
||||
if args.no_decode:
|
||||
print('Not testing decoding.')
|
||||
else:
|
||||
print('Error: {}'.format(np.sum(errors)))
|
||||
|
||||
print('Took {:.3f} seconds / example'.format((time.time() - start) / t))
|
||||
print('Final bpd: {:.3f} error: {}'.format(
|
||||
np.mean(sizes) / np.prod(data.size()[1:]),
|
||||
np.sum(errors)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
run(args, kwargs)
|
||||
|
|
@ -1,105 +0,0 @@
|
|||
# !/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import time
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
import os
|
||||
from optimization.training import train, evaluate
|
||||
from utils.load_data import load_dataset
|
||||
from utils.plotting import plot_training_curve
|
||||
import imageio
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')
|
||||
|
||||
parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'],
|
||||
metavar='DATASET',
|
||||
help='Dataset choice.')
|
||||
|
||||
parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE',
|
||||
help='input batch size for training (default: 100)')
|
||||
|
||||
parser.add_argument('--data_augmentation_level', type=int, default=2,
|
||||
help='data augmentation level')
|
||||
|
||||
parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
|
||||
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
|
||||
|
||||
|
||||
def run(args, kwargs):
|
||||
|
||||
args.snap_dir = snap_dir = \
|
||||
'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'
|
||||
|
||||
# ==================================================================================================================
|
||||
# SNAPSHOTS
|
||||
# ==================================================================================================================
|
||||
|
||||
# ==================================================================================================================
|
||||
# LOAD DATA
|
||||
# ==================================================================================================================
|
||||
train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)
|
||||
|
||||
final_model = torch.load(snap_dir + 'a.model')
|
||||
if hasattr(final_model, 'module'):
|
||||
final_model = final_model.module
|
||||
|
||||
from models.backround import SmoothRound
|
||||
for module in final_model.modules():
|
||||
if isinstance(module, SmoothRound):
|
||||
module._round_decay = 1.
|
||||
|
||||
exp_dir = snap_dir + 'partials/'
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
|
||||
images = []
|
||||
with torch.no_grad():
|
||||
for data, _ in test_loader:
|
||||
|
||||
if args.cuda:
|
||||
data = data.cuda()
|
||||
|
||||
for i in range(len(data)):
|
||||
_, _, _, pz, z, pys, ys, ldj = final_model.forward(data[i:i+1])
|
||||
|
||||
for j in range(len(ys) + 1):
|
||||
x_recon = final_model.inverse(
|
||||
z,
|
||||
ys[len(ys) - j:])
|
||||
|
||||
images.append(x_recon.float())
|
||||
|
||||
if i == 10:
|
||||
break
|
||||
break
|
||||
|
||||
for j in range(len(ys) + 1):
|
||||
|
||||
grid = make_grid(
|
||||
torch.stack(images[j::len(ys) + 1], dim=0).squeeze(),
|
||||
nrow=11, padding=0,
|
||||
normalize=True, range=None,
|
||||
scale_each=False, pad_value=0)
|
||||
|
||||
imageio.imwrite(
|
||||
exp_dir + 'loaded{j}.png'.format(j=j),
|
||||
grid.cpu().numpy().transpose(1, 2, 0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
run(args, kwargs)
|
||||
|
|
@ -1,285 +0,0 @@
|
|||
# !/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import time
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
import math
|
||||
import random
|
||||
|
||||
import os
|
||||
|
||||
import datetime
|
||||
|
||||
from optimization.training import train, evaluate
|
||||
from utils.load_data import load_dataset
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')
|
||||
|
||||
parser.add_argument('-d', '--dataset', type=str, default='cifar10',
|
||||
choices=['cifar10', 'imagenet32', 'imagenet64'],
|
||||
metavar='DATASET',
|
||||
help='Dataset choice.')
|
||||
|
||||
parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
|
||||
parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')
|
||||
|
||||
parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
|
||||
help='how many batches to wait before logging training status')
|
||||
|
||||
parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
|
||||
help='Evaluate per how many epochs')
|
||||
|
||||
parser.add_argument('-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR',
|
||||
help='output directory for model snapshots etc.')
|
||||
|
||||
fp = parser.add_mutually_exclusive_group(required=False)
|
||||
fp.add_argument('-te', '--testing', action='store_true', dest='testing',
|
||||
help='evaluate on test set after training')
|
||||
fp.add_argument('-va', '--validation', action='store_false', dest='testing',
|
||||
help='only evaluate on validation set')
|
||||
parser.set_defaults(testing=True)
|
||||
|
||||
# optimization settings
|
||||
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
|
||||
help='number of epochs to train (default: 2000)')
|
||||
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
|
||||
help='number of early stopping epochs')
|
||||
|
||||
parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE',
|
||||
help='input batch size for training (default: 100)')
|
||||
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
|
||||
help='learning rate')
|
||||
parser.add_argument('--warmup', type=int, default=10,
|
||||
help='number of warmup epochs')
|
||||
|
||||
parser.add_argument('--data_augmentation_level', type=int, default=2,
|
||||
help='data augmentation level')
|
||||
|
||||
parser.add_argument('--variable_type', type=str, default='discrete',
|
||||
help='variable type of data distribution: discrete/continuous',
|
||||
choices=['discrete', 'continuous'])
|
||||
parser.add_argument('--distribution_type', type=str, default='logistic',
|
||||
choices=['logistic', 'normal', 'steplogistic'],
|
||||
help='distribution type: logistic/normal')
|
||||
parser.add_argument('--n_flows', type=int, default=8,
|
||||
help='number of flows per level')
|
||||
parser.add_argument('--n_levels', type=int, default=3,
|
||||
help='number of levels')
|
||||
|
||||
parser.add_argument('--n_bits', type=int, default=8,
|
||||
help='')
|
||||
|
||||
# ---------------- SETTINGS CONCERNING NETWORKS -------------
|
||||
parser.add_argument('--densenet_depth', type=int, default=8,
|
||||
help='Depth of densenets')
|
||||
parser.add_argument('--n_channels', type=int, default=512,
|
||||
help='number of channels in coupling and splitprior')
|
||||
# ---------------- ----------------------------- -------------
|
||||
|
||||
|
||||
# ---------------- SETTINGS CONCERNING COUPLING LAYERS -------------
|
||||
parser.add_argument('--coupling_type', type=str, default='shallow',
|
||||
choices=['shallow', 'resnet', 'densenet'],
|
||||
help='Type of coupling layer')
|
||||
parser.add_argument('--splitfactor', default=0, type=int,
|
||||
help='Split factor for coupling layers.')
|
||||
|
||||
parser.add_argument('--split_quarter', dest='split_quarter', action='store_true',
|
||||
help='Split coupling layer on quarter')
|
||||
parser.add_argument('--no_split_quarter', dest='split_quarter', action='store_false')
|
||||
parser.set_defaults(split_quarter=True)
|
||||
# ---------------- ----------------------------------- -------------
|
||||
|
||||
|
||||
# ---------------- SETTINGS CONCERNING SPLITPRIORS -------------
|
||||
parser.add_argument('--splitprior_type', type=str, default='shallow',
|
||||
choices=['none', 'shallow', 'resnet', 'densenet'],
|
||||
help='Type of splitprior. Use \'none\' for no splitprior')
|
||||
# ---------------- ------------------------------- -------------
|
||||
|
||||
|
||||
# ---------------- SETTINGS CONCERNING PRIORS -------------
|
||||
parser.add_argument('--n_mixtures', type=int, default=1,
|
||||
help='number of mixtures')
|
||||
# ---------------- ------------------------------- -------------
|
||||
|
||||
parser.add_argument('--hard_round', dest='hard_round', action='store_true',
|
||||
help='Rounding of translation in discrete models. Weird '
|
||||
'probabilistic implications, only for experimental phase')
|
||||
parser.add_argument('--no_hard_round', dest='hard_round', action='store_false')
|
||||
parser.set_defaults(hard_round=True)
|
||||
|
||||
parser.add_argument('--round_approx', type=str, default='smooth',
|
||||
choices=['smooth', 'stochastic'])
|
||||
|
||||
parser.add_argument('--lr_decay', default=0.999, type=float,
|
||||
help='Learning rate')
|
||||
|
||||
parser.add_argument('--temperature', default=1.0, type=float,
|
||||
help='Temperature used for BackRound. It is used in '
|
||||
'the the SmoothRound module. '
|
||||
'(default=1.0')
|
||||
|
||||
# gpu/cpu
|
||||
parser.add_argument('--gpu_num', type=int, default=0, metavar='GPU',
|
||||
help='choose GPU to run on.')
|
||||
|
||||
args = parser.parse_args()
|
||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
|
||||
if args.manual_seed is None:
|
||||
args.manual_seed = random.randint(1, 100000)
|
||||
random.seed(args.manual_seed)
|
||||
torch.manual_seed(args.manual_seed)
|
||||
np.random.seed(args.manual_seed)
|
||||
|
||||
|
||||
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
|
||||
|
||||
|
||||
def run(args, kwargs):
|
||||
|
||||
print('\nMODEL SETTINGS: \n', args, '\n')
|
||||
print("Random Seed: ", args.manual_seed)
|
||||
|
||||
if 'imagenet' in args.dataset and args.evaluate_interval_epochs > 5:
|
||||
args.evaluate_interval_epochs = 5
|
||||
|
||||
# ==================================================================================================================
|
||||
# SNAPSHOTS
|
||||
# ==================================================================================================================
|
||||
args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
|
||||
args.model_signature = args.model_signature.replace(':', '_')
|
||||
|
||||
snapshots_path = os.path.join(args.out_dir, args.variable_type + '_' + args.distribution_type + args.dataset)
|
||||
snap_dir = snapshots_path
|
||||
|
||||
snap_dir += '_' + 'flows_' + str(args.n_flows) + '_levels_' + str(args.n_levels)
|
||||
|
||||
snap_dir = snap_dir + '__' + args.model_signature + '/'
|
||||
|
||||
args.snap_dir = snap_dir
|
||||
|
||||
if not os.path.exists(snap_dir):
|
||||
os.makedirs(snap_dir)
|
||||
|
||||
with open(snap_dir + 'log.txt', 'a') as ff:
|
||||
print('\nMODEL SETTINGS: \n', args, '\n', file=ff)
|
||||
|
||||
# SAVING
|
||||
torch.save(args, snap_dir + '.config')
|
||||
|
||||
# ==================================================================================================================
|
||||
# LOAD DATA
|
||||
# ==================================================================================================================
|
||||
train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)
|
||||
|
||||
# ==================================================================================================================
|
||||
# SELECT MODEL
|
||||
# ==================================================================================================================
|
||||
# flow parameters and architecture choice are passed on to model through args
|
||||
print(args.input_size)
|
||||
|
||||
import models.Model as Model
|
||||
|
||||
model = Model.Model(args)
|
||||
args.DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
model.set_temperature(args.temperature)
|
||||
model.enable_hard_round(args.hard_round)
|
||||
|
||||
model_sample = model
|
||||
|
||||
# ====================================
|
||||
# INIT
|
||||
# ====================================
|
||||
# data dependend initialization on CPU
|
||||
for batch_idx, (data, _) in enumerate(train_loader):
|
||||
model(data)
|
||||
break
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||
model = torch.nn.DataParallel(model, dim=0)
|
||||
|
||||
model.to(args.DEVICE)
|
||||
|
||||
def lr_lambda(epoch):
|
||||
return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)
|
||||
optimizer = optim.Adamax(model.parameters(), lr=args.learning_rate, eps=1.e-7)
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
|
||||
|
||||
# ==================================================================================================================
|
||||
# TRAINING
|
||||
# ==================================================================================================================
|
||||
train_bpd = []
|
||||
val_bpd = []
|
||||
|
||||
# for early stopping
|
||||
best_val_bpd = np.inf
|
||||
best_train_bpd = np.inf
|
||||
epoch = 0
|
||||
|
||||
train_times = []
|
||||
|
||||
model.eval()
|
||||
model.train()
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
t_start = time.time()
|
||||
scheduler.step()
|
||||
tr_loss, tr_bpd = train(epoch, train_loader, model, optimizer, args)
|
||||
train_bpd.append(tr_bpd)
|
||||
train_times.append(time.time()-t_start)
|
||||
print('One training epoch took %.2f seconds' % (time.time()-t_start))
|
||||
|
||||
if epoch < 25 or epoch % args.evaluate_interval_epochs == 0:
|
||||
v_loss, v_bpd = evaluate(
|
||||
train_loader, val_loader, model, model_sample, args,
|
||||
epoch=epoch, file=snap_dir + 'log.txt')
|
||||
|
||||
val_bpd.append(v_bpd)
|
||||
|
||||
# Model save based on TRAIN performance (is heavily correlated with validation performance.)
|
||||
if np.mean(tr_bpd) < best_train_bpd:
|
||||
best_train_bpd = np.mean(tr_bpd)
|
||||
best_val_bpd = v_bpd
|
||||
torch.save(model.module, snap_dir + 'a.model')
|
||||
torch.save(optimizer, snap_dir + 'a.optimizer')
|
||||
print('->model saved<-')
|
||||
|
||||
print('(BEST: train bpd {:.4f}, test bpd {:.4f})\n'.format(
|
||||
best_train_bpd, best_val_bpd))
|
||||
|
||||
if math.isnan(v_loss):
|
||||
raise ValueError('NaN encountered!')
|
||||
|
||||
train_bpd = np.hstack(train_bpd)
|
||||
val_bpd = np.array(val_bpd)
|
||||
|
||||
# training time per epoch
|
||||
train_times = np.array(train_times)
|
||||
mean_train_time = np.mean(train_times)
|
||||
std_train_time = np.std(train_times, ddof=1)
|
||||
print('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time))
|
||||
|
||||
# ==================================================================================================================
|
||||
# EVALUATION
|
||||
# ==================================================================================================================
|
||||
final_model = torch.load(snap_dir + 'a.model')
|
||||
test_loss, test_bpd = evaluate(
|
||||
train_loader, test_loader, final_model, final_model, args,
|
||||
epoch=epoch, file=snap_dir + 'test_log.txt')
|
||||
|
||||
print('Test loss / bpd: %.2f / %.2f' % (test_loss, test_bpd))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
run(args, kwargs)
|
||||
Binary file not shown.
Binary file not shown.
|
|
@ -1,191 +0,0 @@
|
|||
import torch
|
||||
import models.generative_flows as generative_flows
|
||||
import numpy as np
|
||||
from models.utils import Base
|
||||
from .priors import Prior
|
||||
from optimization.loss import compute_loss_array
|
||||
from coding.coder import encode_sample, decode_sample
|
||||
|
||||
|
||||
class Normalize(Base):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.n_bits = args.n_bits
|
||||
self.variable_type = args.variable_type
|
||||
self.input_size = args.input_size
|
||||
|
||||
def forward(self, x, ldj, reverse=False):
|
||||
domain = 2.**self.n_bits
|
||||
|
||||
if self.variable_type == 'discrete':
|
||||
# Discrete variables will be measured on intervals sized 1/domain.
|
||||
# Hence, there is no need to change the log Jacobian determinant.
|
||||
dldj = 0
|
||||
elif self.variable_type == 'continuous':
|
||||
dldj = -np.log(domain) * np.prod(self.input_size)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
if not reverse:
|
||||
x = (x - domain / 2) / domain
|
||||
ldj += dldj
|
||||
else:
|
||||
x = x * domain + domain / 2
|
||||
ldj -= dldj
|
||||
|
||||
return x, ldj
|
||||
|
||||
|
||||
class Model(Base):
|
||||
"""
|
||||
The base VAE class containing gated convolutional encoder and decoder
|
||||
architecture. Can be used as a base class for VAE's with normalizing flows.
|
||||
"""
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.variable_type = args.variable_type
|
||||
self.distribution_type = args.distribution_type
|
||||
|
||||
n_channels, height, width = args.input_size
|
||||
|
||||
self.normalize = Normalize(args)
|
||||
|
||||
self.flow = generative_flows.GenerativeFlow(
|
||||
n_channels, height, width, args)
|
||||
|
||||
self.n_bits = args.n_bits
|
||||
|
||||
self.z_size = self.flow.z_size
|
||||
|
||||
self.prior = Prior(self.z_size, args)
|
||||
|
||||
def dequantize(self, x):
|
||||
if self.training:
|
||||
x = x + torch.rand_like(x)
|
||||
else:
|
||||
# Required for stability.
|
||||
alpha = 1e-3
|
||||
x = x + alpha + torch.rand_like(x) * (1 - 2 * alpha)
|
||||
|
||||
return x
|
||||
|
||||
def loss(self, pz, z, pys, ys, ldj):
|
||||
batchsize = z.size(0)
|
||||
loss, bpd, bpd_per_prior = \
|
||||
compute_loss_array(pz, z, pys, ys, ldj, self.args)
|
||||
|
||||
for module in self.modules():
|
||||
if hasattr(module, 'auxillary_loss'):
|
||||
loss += module.auxillary_loss() / batchsize
|
||||
|
||||
return loss, bpd, bpd_per_prior
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Evaluates the model as a whole, encodes and decodes. Note that the log
|
||||
det jacobian is zero for a plain VAE (without flows), and z_0 = z_k.
|
||||
"""
|
||||
# Decode z to x.
|
||||
|
||||
assert x.dtype == torch.uint8
|
||||
|
||||
x = x.float()
|
||||
|
||||
ldj = torch.zeros_like(x[:, 0, 0, 0])
|
||||
if self.variable_type == 'continuous':
|
||||
x = self.dequantize(x)
|
||||
elif self.variable_type == 'discrete':
|
||||
pass
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
x, ldj = self.normalize(x, ldj)
|
||||
|
||||
z, ldj, pys, ys = self.flow(x, ldj, pys=(), ys=())
|
||||
|
||||
pz, z, ldj = self.prior(z, ldj)
|
||||
|
||||
loss, bpd, bpd_per_prior = self.loss(pz, z, pys, ys, ldj)
|
||||
|
||||
return loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj
|
||||
|
||||
def inverse(self, z, ys):
|
||||
ldj = torch.zeros_like(z[:, 0, 0, 0])
|
||||
x, ldj, pys, py = \
|
||||
self.flow(z, ldj, pys=[], ys=ys, reverse=True)
|
||||
|
||||
x, ldj = self.normalize(x, ldj, reverse=True)
|
||||
|
||||
x_uint8 = torch.clamp(x, min=0, max=255).to(
|
||||
torch.uint8)
|
||||
|
||||
return x_uint8
|
||||
|
||||
def sample(self, n):
|
||||
z_sample = self.prior.sample(n)
|
||||
|
||||
ldj = torch.zeros_like(z_sample[:, 0, 0, 0])
|
||||
x_sample, ldj, pys, py = \
|
||||
self.flow(z_sample, ldj, pys=[], ys=[], reverse=True)
|
||||
|
||||
x_sample, ldj = self.normalize(x_sample, ldj, reverse=True)
|
||||
|
||||
x_sample_uint8 = torch.clamp(x_sample, min=0, max=255).to(
|
||||
torch.uint8)
|
||||
|
||||
return x_sample_uint8
|
||||
|
||||
def encode(self, x):
|
||||
batchsize = x.size(0)
|
||||
_, _, _, pz, z, pys, ys, _ = self.forward(x)
|
||||
|
||||
pjs = list(pys) + [pz]
|
||||
js = list(ys) + [z]
|
||||
|
||||
states = []
|
||||
|
||||
for b in range(batchsize):
|
||||
state = None
|
||||
for pj, j in zip(pjs, js):
|
||||
pj_b = [param[b:b+1] for param in pj]
|
||||
j_b = j[b:b+1]
|
||||
|
||||
state = encode_sample(
|
||||
j_b, pj_b, self.variable_type,
|
||||
self.distribution_type, state=state)
|
||||
if state is None:
|
||||
break
|
||||
|
||||
states.append(state)
|
||||
|
||||
return states
|
||||
|
||||
def decode(self, states):
|
||||
def decode_fn(states, pj):
|
||||
states = list(states)
|
||||
j = []
|
||||
|
||||
for b in range(len(states)):
|
||||
pj_b = [param[b:b+1] for param in pj]
|
||||
|
||||
states[b], j_b = decode_sample(
|
||||
states[b], pj_b, self.variable_type,
|
||||
self.distribution_type)
|
||||
|
||||
j.append(j_b)
|
||||
|
||||
j = torch.cat(j, dim=0)
|
||||
return states, j
|
||||
|
||||
states, z = self.prior.decode(states, decode_fn=decode_fn)
|
||||
|
||||
ldj = torch.zeros_like(z[:, 0, 0, 0])
|
||||
|
||||
x, ldj = self.flow.decode(z, ldj, states, decode_fn=decode_fn)
|
||||
|
||||
x, ldj = self.normalize(x, ldj, reverse=True)
|
||||
x = x.to(dtype=torch.uint8)
|
||||
|
||||
return x
|
||||
|
|
@ -1,151 +0,0 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from models.utils import Base
|
||||
|
||||
|
||||
class RoundStraightThrough(torch.autograd.Function):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
rounded = torch.round(input, out=None)
|
||||
return rounded
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_input = grad_output.clone()
|
||||
return grad_input
|
||||
|
||||
|
||||
_round_straightthrough = RoundStraightThrough().apply
|
||||
|
||||
|
||||
def _stacked_sigmoid(x, temperature, n_approx=3):
|
||||
|
||||
x_ = x - 0.5
|
||||
rounded = torch.round(x_)
|
||||
x_remainder = x_ - rounded
|
||||
|
||||
size = x_.size()
|
||||
x_remainder = x_remainder.view(size + (1,))
|
||||
|
||||
translation = torch.arange(n_approx) - n_approx // 2
|
||||
translation = translation.to(device=x.DEVICE, dtype=x.dtype)
|
||||
translation = translation.view([1] * len(size) + [len(translation)])
|
||||
out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)
|
||||
|
||||
return out + rounded - (n_approx // 2)
|
||||
|
||||
|
||||
class SmoothRound(Base):
|
||||
def __init__(self):
|
||||
self._temperature = None
|
||||
self._n_approx = None
|
||||
super().__init__()
|
||||
self.hard_round = None
|
||||
|
||||
@property
|
||||
def temperature(self):
|
||||
return self._temperature
|
||||
|
||||
@temperature.setter
|
||||
def temperature(self, value):
|
||||
self._temperature = value
|
||||
|
||||
if self._temperature <= 0.05:
|
||||
self._n_approx = 1
|
||||
elif 0.05 < self._temperature < 0.13:
|
||||
self._n_approx = 3
|
||||
else:
|
||||
self._n_approx = 5
|
||||
|
||||
def forward(self, x):
|
||||
assert self._temperature is not None
|
||||
assert self._n_approx is not None
|
||||
assert self.hard_round is not None
|
||||
|
||||
if self.temperature <= 0.25:
|
||||
h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx)
|
||||
else:
|
||||
h = x
|
||||
|
||||
if self.hard_round:
|
||||
h = _round_straightthrough(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class StochasticRound(Base):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.hard_round = None
|
||||
|
||||
def forward(self, x):
|
||||
u = torch.rand_like(x)
|
||||
|
||||
h = x + u - 0.5
|
||||
|
||||
if self.hard_round:
|
||||
h = _round_straightthrough(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class BackRound(Base):
|
||||
|
||||
def __init__(self, args, inverse_bin_width):
|
||||
"""
|
||||
BackRound is an approximation to Round that allows for Backpropagation.
|
||||
|
||||
Approximate the round function using a sum of translated sigmoids.
|
||||
The temperature determines how well the round function is approximated,
|
||||
i.e., a lower temperature corresponds to a better approximation, at
|
||||
the cost of more vanishing gradients.
|
||||
|
||||
BackRound supports the following settings:
|
||||
* By setting hard to True and temperature > 0.25, BackRound
|
||||
reduces to a round function with a straight through gradient
|
||||
estimator
|
||||
* When using 0 < temperature <= 0.25 and hard = True, the
|
||||
output in the forward pass is equivalent to a round function, but the
|
||||
gradient is approximated by the gradient of a sum of sigmoids.
|
||||
* When using hard = False, the output is not constrained to integers.
|
||||
* When temperature > 0.25 and hard = False, BackRound reduces to
|
||||
the identity function.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
temperature: float
|
||||
Temperature used for stacked sigmoid approximated. If temperature
|
||||
is greater than 0.25, the approximation reduces to the indentiy
|
||||
function.
|
||||
hard: bool
|
||||
If hard is True, a (hard) round is applied before returning. The
|
||||
gradient for this is approximated using the straight-through
|
||||
estimator.
|
||||
"""
|
||||
super().__init__()
|
||||
self.inverse_bin_width = inverse_bin_width
|
||||
self.round_approx = args.round_approx
|
||||
|
||||
if args.round_approx == 'smooth':
|
||||
self.round = SmoothRound()
|
||||
elif args.round_approx == 'stochastic':
|
||||
self.round = StochasticRound()
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def forward(self, x):
|
||||
if self.round_approx == 'smooth' or self.round_approx == 'stochastic':
|
||||
h = x * self.inverse_bin_width
|
||||
|
||||
h = self.round(h)
|
||||
|
||||
return h / self.inverse_bin_width
|
||||
|
||||
else:
|
||||
raise ValueError
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
"""
|
||||
Collection of flow strategies
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from models.utils import Base
|
||||
from .backround import BackRound
|
||||
from .networks import NN
|
||||
|
||||
|
||||
UNIT_TESTING = False
|
||||
|
||||
|
||||
class SplitFactorCoupling(Base):
|
||||
def __init__(self, c_in, factor, height, width, args):
|
||||
super().__init__()
|
||||
self.n_channels = args.n_channels
|
||||
self.kernel = 3
|
||||
self.input_channel = c_in
|
||||
self.round_approx = args.round_approx
|
||||
|
||||
if args.variable_type == 'discrete':
|
||||
self.round = BackRound(
|
||||
args, inverse_bin_width=2**args.n_bits)
|
||||
else:
|
||||
self.round = None
|
||||
|
||||
self.split_idx = c_in - (c_in // factor)
|
||||
|
||||
self.nn = NN(
|
||||
args=args,
|
||||
c_in=self.split_idx,
|
||||
c_out=c_in - self.split_idx,
|
||||
height=height,
|
||||
width=width,
|
||||
kernel=self.kernel,
|
||||
nn_type=args.coupling_type)
|
||||
|
||||
def forward(self, z, ldj, reverse=False):
|
||||
z1 = z[:, :self.split_idx, :, :]
|
||||
z2 = z[:, self.split_idx:, :, :]
|
||||
|
||||
t = self.nn(z1)
|
||||
|
||||
if self.round is not None:
|
||||
t = self.round(t)
|
||||
|
||||
if not reverse:
|
||||
z2 = z2 + t
|
||||
else:
|
||||
z2 = z2 - t
|
||||
|
||||
z = torch.cat([z1, z2], dim=1)
|
||||
|
||||
return z, ldj
|
||||
|
||||
|
||||
class Coupling(Base):
|
||||
def __init__(self, c_in, height, width, args):
|
||||
super().__init__()
|
||||
|
||||
if args.split_quarter:
|
||||
factor = 4
|
||||
elif args.splitfactor > 1:
|
||||
factor = args.splitfactor
|
||||
else:
|
||||
factor = 2
|
||||
|
||||
self.coupling = SplitFactorCoupling(
|
||||
c_in, factor, height, width, args=args)
|
||||
|
||||
def forward(self, z, ldj, reverse=False):
|
||||
return self.coupling(z, ldj, reverse)
|
||||
|
||||
|
||||
def test_generative_flow():
|
||||
import models.networks as networks
|
||||
global UNIT_TESTING
|
||||
|
||||
networks.UNIT_TESTING = True
|
||||
UNIT_TESTING = True
|
||||
|
||||
batch_size = 17
|
||||
|
||||
input_size = [12, 16, 16]
|
||||
|
||||
class Args():
|
||||
def __init__(self):
|
||||
self.input_size = input_size
|
||||
self.learn_split = False
|
||||
self.variable_type = 'continuous'
|
||||
self.distribution_type = 'logistic'
|
||||
self.round_approx = 'smooth'
|
||||
self.coupling_type = 'shallow'
|
||||
self.conv_type = 'standard'
|
||||
self.densenet_depth = 8
|
||||
self.bottleneck = False
|
||||
self.n_channels = 512
|
||||
self.network1x1 = 'standard'
|
||||
self.auxilary_freq = -1
|
||||
self.actnorm = False
|
||||
self.LU = False
|
||||
self.coupling_lifting_L = True
|
||||
self.splitprior = True
|
||||
self.split_quarter = True
|
||||
self.n_levels = 2
|
||||
self.n_flows = 2
|
||||
self.cond_L = True
|
||||
self.n_bits = True
|
||||
|
||||
args = Args()
|
||||
|
||||
x = (torch.randint(256, size=[batch_size] + input_size).float() - 128.) / 256.
|
||||
ldj = torch.zeros_like(x[:, 0, 0, 0])
|
||||
|
||||
model = Coupling(c_in=12, height=16, width=16, args=args)
|
||||
|
||||
print(model)
|
||||
|
||||
model.set_temperature(1.)
|
||||
model.enable_hard_round()
|
||||
|
||||
model.eval()
|
||||
|
||||
z, ldj = model(x, ldj, reverse=False)
|
||||
|
||||
# Check if gradient computation works
|
||||
loss = torch.sum(z**2)
|
||||
loss.backward()
|
||||
|
||||
recon, ldj = model(z, ldj, reverse=True)
|
||||
|
||||
sse = torch.sum(torch.pow(x - recon, 2)).item()
|
||||
ae = torch.abs(x - recon).sum()
|
||||
print('Error in recon: sse {} ae {}'.format(sse / np.prod(input_size), ae))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_generative_flow()
|
||||
|
|
@ -1,175 +0,0 @@
|
|||
"""
|
||||
Collection of flow strategies
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from models.utils import Base
|
||||
from .priors import SplitPrior
|
||||
from .coupling import Coupling
|
||||
|
||||
|
||||
UNIT_TESTING = False
|
||||
|
||||
|
||||
def space_to_depth(x):
|
||||
xs = x.size()
|
||||
# Pick off every second element
|
||||
x = x.view(xs[0], xs[1], xs[2] // 2, 2, xs[3] // 2, 2)
|
||||
# Transpose picked elements next to channels.
|
||||
x = x.permute((0, 1, 3, 5, 2, 4)).contiguous()
|
||||
# Combine with channels.
|
||||
x = x.view(xs[0], xs[1] * 4, xs[2] // 2, xs[3] // 2)
|
||||
return x
|
||||
|
||||
|
||||
def depth_to_space(x):
|
||||
xs = x.size()
|
||||
# Pick off elements from channels
|
||||
x = x.view(xs[0], xs[1] // 4, 2, 2, xs[2], xs[3])
|
||||
# Transpose picked elements next to HW dimensions.
|
||||
x = x.permute((0, 1, 4, 2, 5, 3)).contiguous()
|
||||
# Combine with HW dimensions.
|
||||
x = x.view(xs[0], xs[1] // 4, xs[2] * 2, xs[3] * 2)
|
||||
return x
|
||||
|
||||
|
||||
def int_shape(x):
|
||||
return list(map(int, x.size()))
|
||||
|
||||
|
||||
class Flatten(Base):
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
class Reshape(Base):
|
||||
def __init__(self, shape):
|
||||
super().__init__()
|
||||
self.shape = shape
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), *self.shape)
|
||||
|
||||
|
||||
class Reverse(Base):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, z, reverse=False):
|
||||
flip_idx = torch.arange(z.size(1) - 1, -1, -1).long()
|
||||
z = z[:, flip_idx, :, :]
|
||||
return z
|
||||
|
||||
|
||||
class Permute(Base):
|
||||
def __init__(self, n_channels):
|
||||
super().__init__()
|
||||
|
||||
permutation = np.arange(n_channels, dtype='int')
|
||||
np.random.shuffle(permutation)
|
||||
|
||||
permutation_inv = np.zeros(n_channels, dtype='int')
|
||||
permutation_inv[permutation] = np.arange(n_channels, dtype='int')
|
||||
|
||||
self.permutation = torch.from_numpy(permutation)
|
||||
self.permutation_inv = torch.from_numpy(permutation_inv)
|
||||
|
||||
def forward(self, z, ldj, reverse=False):
|
||||
if not reverse:
|
||||
z = z[:, self.permutation, :, :]
|
||||
else:
|
||||
z = z[:, self.permutation_inv, :, :]
|
||||
|
||||
return z, ldj
|
||||
|
||||
def InversePermute(self):
|
||||
inv_permute = Permute(len(self.permutation))
|
||||
inv_permute.permutation = self.permutation_inv
|
||||
inv_permute.permutation_inv = self.permutation
|
||||
return inv_permute
|
||||
|
||||
|
||||
class Squeeze(Base):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, z, ldj, reverse=False):
|
||||
if not reverse:
|
||||
z = space_to_depth(z)
|
||||
else:
|
||||
z = depth_to_space(z)
|
||||
return z, ldj
|
||||
|
||||
|
||||
class GenerativeFlow(Base):
|
||||
def __init__(self, n_channels, height, width, args):
|
||||
super().__init__()
|
||||
layers = []
|
||||
layers.append(Squeeze())
|
||||
n_channels *= 4
|
||||
height //= 2
|
||||
width //= 2
|
||||
|
||||
for level in range(args.n_levels):
|
||||
|
||||
for i in range(args.n_flows):
|
||||
perm_layer = Permute(n_channels)
|
||||
layers.append(perm_layer)
|
||||
|
||||
layers.append(
|
||||
Coupling(n_channels, height, width, args))
|
||||
|
||||
if level < args.n_levels - 1:
|
||||
if args.splitprior_type != 'none':
|
||||
# Standard splitprior
|
||||
factor_out = n_channels // 2
|
||||
layers.append(SplitPrior(n_channels, factor_out, height, width, args))
|
||||
n_channels = n_channels - factor_out
|
||||
|
||||
layers.append(Squeeze())
|
||||
n_channels *= 4
|
||||
height //= 2
|
||||
width //= 2
|
||||
|
||||
self.layers = torch.nn.ModuleList(layers)
|
||||
self.z_size = (n_channels, height, width)
|
||||
|
||||
def forward(self, z, ldj, pys=(), ys=(), reverse=False):
|
||||
if not reverse:
|
||||
for l, layer in enumerate(self.layers):
|
||||
if isinstance(layer, (SplitPrior)):
|
||||
py, y, z, ldj = layer(z, ldj)
|
||||
pys += (py,)
|
||||
ys += (y,)
|
||||
|
||||
else:
|
||||
z, ldj = layer(z, ldj)
|
||||
|
||||
else:
|
||||
for l, layer in reversed(list(enumerate(self.layers))):
|
||||
if isinstance(layer, (SplitPrior)):
|
||||
if len(ys) > 0:
|
||||
z, ldj = layer.inverse(z, ldj, y=ys[-1])
|
||||
# Pop last element
|
||||
ys = ys[:-1]
|
||||
else:
|
||||
z, ldj = layer.inverse(z, ldj, y=None)
|
||||
|
||||
else:
|
||||
z, ldj = layer(z, ldj, reverse=True)
|
||||
|
||||
return z, ldj, pys, ys
|
||||
|
||||
def decode(self, z, ldj, state, decode_fn):
|
||||
|
||||
for l, layer in reversed(list(enumerate(self.layers))):
|
||||
if isinstance(layer, SplitPrior):
|
||||
z, ldj, state = layer.decode(z, ldj, state, decode_fn)
|
||||
|
||||
else:
|
||||
z, ldj = layer(z, ldj, reverse=True)
|
||||
|
||||
return z, ldj
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
"""
|
||||
Collection of flow strategies
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from models.utils import Base
|
||||
|
||||
|
||||
UNIT_TESTING = False
|
||||
|
||||
|
||||
class Conv2dReLU(Base):
|
||||
def __init__(
|
||||
self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0,
|
||||
bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.nn = nn.Conv2d(n_inputs, n_outputs, kernel_size, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.nn(x)
|
||||
|
||||
y = F.relu(h)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class ResidualBlock(Base):
|
||||
def __init__(self, n_channels, kernel, Conv2dAct):
|
||||
super().__init__()
|
||||
|
||||
self.nn = torch.nn.Sequential(
|
||||
Conv2dAct(n_channels, n_channels, kernel, padding=1),
|
||||
torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.nn(x)
|
||||
h = F.relu(h + x)
|
||||
return h
|
||||
|
||||
|
||||
class DenseLayer(Base):
|
||||
def __init__(self, args, n_inputs, growth, Conv2dAct):
|
||||
super().__init__()
|
||||
|
||||
conv1x1 = Conv2dAct(
|
||||
n_inputs, n_inputs, kernel_size=1, stride=1,
|
||||
padding=0, bias=True)
|
||||
|
||||
self.nn = torch.nn.Sequential(
|
||||
conv1x1,
|
||||
Conv2dAct(
|
||||
n_inputs, growth, kernel_size=3, stride=1,
|
||||
padding=1, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.nn(x)
|
||||
|
||||
h = torch.cat([x, h], dim=1)
|
||||
return h
|
||||
|
||||
|
||||
class DenseBlock(Base):
|
||||
def __init__(
|
||||
self, args, n_inputs, n_outputs, kernel, Conv2dAct):
|
||||
super().__init__()
|
||||
depth = args.densenet_depth
|
||||
|
||||
future_growth = n_outputs - n_inputs
|
||||
|
||||
layers = []
|
||||
|
||||
for d in range(depth):
|
||||
growth = future_growth // (depth - d)
|
||||
|
||||
layers.append(DenseLayer(args, n_inputs, growth, Conv2dAct))
|
||||
n_inputs += growth
|
||||
future_growth -= growth
|
||||
|
||||
self.nn = torch.nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.nn(x)
|
||||
|
||||
|
||||
class Identity(Base):
|
||||
def __init__(self):
|
||||
super.__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class NN(Base):
|
||||
def __init__(
|
||||
self, args, c_in, c_out, height, width, nn_type, kernel=3):
|
||||
super().__init__()
|
||||
|
||||
Conv2dAct = Conv2dReLU
|
||||
n_channels = args.n_channels
|
||||
|
||||
if nn_type == 'shallow':
|
||||
|
||||
if args.network1x1 == 'standard':
|
||||
conv1x1 = Conv2dAct(
|
||||
n_channels, n_channels, kernel_size=1, stride=1,
|
||||
padding=0, bias=False)
|
||||
|
||||
layers = [
|
||||
Conv2dAct(c_in, n_channels, kernel, padding=1),
|
||||
conv1x1]
|
||||
|
||||
layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)]
|
||||
|
||||
elif nn_type == 'resnet':
|
||||
layers = [
|
||||
Conv2dAct(c_in, n_channels, kernel, padding=1),
|
||||
ResidualBlock(n_channels, kernel, Conv2dAct),
|
||||
ResidualBlock(n_channels, kernel, Conv2dAct)]
|
||||
|
||||
layers += [
|
||||
torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)
|
||||
]
|
||||
|
||||
elif nn_type == 'densenet':
|
||||
layers = [
|
||||
DenseBlock(
|
||||
args=args,
|
||||
n_inputs=c_in,
|
||||
n_outputs=n_channels + c_in,
|
||||
kernel=kernel,
|
||||
Conv2dAct=Conv2dAct)]
|
||||
|
||||
layers += [
|
||||
torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1)
|
||||
]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
self.nn = torch.nn.Sequential(*layers)
|
||||
|
||||
# Set parameters of last conv-layer to zero.
|
||||
if not UNIT_TESTING:
|
||||
self.nn[-1].weight.data.zero_()
|
||||
self.nn[-1].bias.data.zero_()
|
||||
|
||||
def forward(self, x):
|
||||
return self.nn(x)
|
||||
|
|
@ -1,164 +0,0 @@
|
|||
"""
|
||||
Collection of flow strategies
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
from utils.distributions import sample_discretized_logistic, \
|
||||
sample_mixture_discretized_logistic, sample_normal, sample_logistic, \
|
||||
sample_discretized_normal, sample_mixture_normal
|
||||
from models.utils import Base
|
||||
from .networks import NN
|
||||
|
||||
|
||||
def sample_prior(px, variable_type, distribution_type, inverse_bin_width):
|
||||
if variable_type == 'discrete':
|
||||
if distribution_type == 'logistic':
|
||||
if len(px) == 2:
|
||||
return sample_discretized_logistic(
|
||||
*px, inverse_bin_width=inverse_bin_width)
|
||||
elif len(px) == 3:
|
||||
return sample_mixture_discretized_logistic(
|
||||
*px, inverse_bin_width=inverse_bin_width)
|
||||
|
||||
elif distribution_type == 'normal':
|
||||
return sample_discretized_normal(
|
||||
*px, inverse_bin_width=inverse_bin_width)
|
||||
|
||||
elif variable_type == 'continuous':
|
||||
if distribution_type == 'logistic':
|
||||
return sample_logistic(*px)
|
||||
elif distribution_type == 'normal':
|
||||
if len(px) == 2:
|
||||
return sample_normal(*px)
|
||||
elif len(px) == 3:
|
||||
return sample_mixture_normal(*px)
|
||||
elif distribution_type == 'steplogistic':
|
||||
return sample_logistic(*px)
|
||||
|
||||
raise ValueError
|
||||
|
||||
|
||||
class Prior(Base):
|
||||
def __init__(self, size, args):
|
||||
super().__init__()
|
||||
c, h, w = size
|
||||
|
||||
self.inverse_bin_width = 2**args.n_bits
|
||||
self.variable_type = args.variable_type
|
||||
self.distribution_type = args.distribution_type
|
||||
self.n_mixtures = args.n_mixtures
|
||||
|
||||
if self.n_mixtures == 1:
|
||||
self.mu = Parameter(torch.Tensor(c, h, w))
|
||||
self.logs = Parameter(torch.Tensor(c, h, w))
|
||||
elif self.n_mixtures > 1:
|
||||
self.mu = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
||||
self.logs = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
||||
self.pi_logit = Parameter(torch.Tensor(c, h, w, self.n_mixtures))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
self.mu.data.zero_()
|
||||
|
||||
if self.n_mixtures > 1:
|
||||
self.pi_logit.data.zero_()
|
||||
for i in range(self.n_mixtures):
|
||||
self.mu.data[..., i] += i - (self.n_mixtures - 1) / 2.
|
||||
|
||||
self.logs.data.zero_()
|
||||
|
||||
def get_pz(self, n):
|
||||
if self.n_mixtures == 1:
|
||||
mu = self.mu.repeat(n, 1, 1, 1)
|
||||
logs = self.logs.repeat(n, 1, 1, 1) # scaling scale
|
||||
return mu, logs
|
||||
|
||||
elif self.n_mixtures > 1:
|
||||
pi = F.softmax(self.pi_logit, dim=-1)
|
||||
mu = self.mu.repeat(n, 1, 1, 1, 1)
|
||||
logs = self.logs.repeat(n, 1, 1, 1, 1)
|
||||
pi = pi.repeat(n, 1, 1, 1, 1)
|
||||
return mu, logs, pi
|
||||
|
||||
def forward(self, z, ldj):
|
||||
pz = self.get_pz(z.size(0))
|
||||
|
||||
return pz, z, ldj
|
||||
|
||||
def sample(self, n):
|
||||
pz = self.get_pz(n)
|
||||
|
||||
z_sample = sample_prior(pz, self.variable_type, self.distribution_type, self.inverse_bin_width)
|
||||
|
||||
return z_sample
|
||||
|
||||
def decode(self, states, decode_fn):
|
||||
pz = self.get_pz(n=len(states))
|
||||
|
||||
states, z = decode_fn(states, pz)
|
||||
return states, z
|
||||
|
||||
|
||||
class SplitPrior(Base):
|
||||
def __init__(self, c_in, factor_out, height, width, args):
|
||||
super().__init__()
|
||||
|
||||
self.split_idx = c_in - factor_out
|
||||
self.inverse_bin_width = 2**args.n_bits
|
||||
self.variable_type = args.variable_type
|
||||
self.distribution_type = args.distribution_type
|
||||
self.input_channel = c_in
|
||||
|
||||
self.nn = NN(
|
||||
args=args,
|
||||
c_in=c_in - factor_out,
|
||||
c_out=factor_out * 2,
|
||||
height=height,
|
||||
width=width,
|
||||
nn_type=args.splitprior_type)
|
||||
|
||||
def get_py(self, z):
|
||||
h = self.nn(z)
|
||||
mu = h[:, ::2, :, :]
|
||||
logs = h[:, 1::2, :, :]
|
||||
|
||||
py = [mu, logs]
|
||||
|
||||
return py
|
||||
|
||||
def split(self, z):
|
||||
z1 = z[:, :self.split_idx, :, :]
|
||||
y = z[:, self.split_idx:, :, :]
|
||||
return z1, y
|
||||
|
||||
def combine(self, z, y):
|
||||
result = torch.cat([z, y], dim=1)
|
||||
|
||||
return result
|
||||
|
||||
def forward(self, z, ldj):
|
||||
z, y = self.split(z)
|
||||
|
||||
py = self.get_py(z)
|
||||
|
||||
return py, y, z, ldj
|
||||
|
||||
def inverse(self, z, ldj, y):
|
||||
# Sample if y is not given.
|
||||
if y is None:
|
||||
py = self.get_py(z)
|
||||
y = sample_prior(py, self.variable_type, self.distribution_type, self.inverse_bin_width)
|
||||
|
||||
z = self.combine(z, y)
|
||||
|
||||
return z, ldj
|
||||
|
||||
def decode(self, z, ldj, states, decode_fn):
|
||||
py = self.get_py(z)
|
||||
states, y = decode_fn(states, py)
|
||||
return self.combine(z, y), ldj, states
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
import torch
|
||||
|
||||
|
||||
class Base(torch.nn.Module):
|
||||
"""
|
||||
The base class for modules. That contains a disable round mode
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def _set_child_attribute(self, attr, value):
|
||||
r"""Sets the module in rounding mode.
|
||||
|
||||
This has any effect only on certain modules if variable type is
|
||||
discrete.
|
||||
|
||||
Returns:
|
||||
Module: self
|
||||
"""
|
||||
if hasattr(self, attr):
|
||||
setattr(self, attr, value)
|
||||
|
||||
for module in self.modules():
|
||||
if hasattr(module, attr):
|
||||
setattr(module, attr, value)
|
||||
return self
|
||||
|
||||
def set_temperature(self, value):
|
||||
self._set_child_attribute("temperature", value)
|
||||
|
||||
def enable_hard_round(self, mode=True):
|
||||
self._set_child_attribute("hard_round", mode)
|
||||
|
||||
def disable_hard_round(self, mode=True):
|
||||
self.enable_hard_round(not mode)
|
||||
|
|
@ -1,148 +0,0 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from utils.distributions import log_discretized_logistic, \
|
||||
log_mixture_discretized_logistic, log_normal, log_discretized_normal, \
|
||||
log_logistic, log_mixture_normal
|
||||
from models.backround import _round_straightthrough
|
||||
|
||||
|
||||
def compute_log_ps(pxs, xs, args):
|
||||
# Add likelihoods of intermediate representations.
|
||||
inverse_bin_width = 2.**args.n_bits
|
||||
|
||||
log_pxs = []
|
||||
for px, x in zip(pxs, xs):
|
||||
|
||||
if args.variable_type == 'discrete':
|
||||
if args.distribution_type == 'logistic':
|
||||
log_px = log_discretized_logistic(
|
||||
x, *px, inverse_bin_width=inverse_bin_width)
|
||||
elif args.distribution_type == 'normal':
|
||||
log_px = log_discretized_normal(
|
||||
x, *px, inverse_bin_width=inverse_bin_width)
|
||||
elif args.variable_type == 'continuous':
|
||||
if args.distribution_type == 'logistic':
|
||||
log_px = log_logistic(x, *px)
|
||||
elif args.distribution_type == 'normal':
|
||||
log_px = log_normal(x, *px)
|
||||
elif args.distribution_type == 'steplogistic':
|
||||
x = _round_straightthrough(x * inverse_bin_width) / inverse_bin_width
|
||||
log_px = log_discretized_logistic(
|
||||
x, *px, inverse_bin_width=inverse_bin_width)
|
||||
|
||||
log_pxs.append(
|
||||
torch.sum(log_px, dim=[1, 2, 3]))
|
||||
|
||||
return log_pxs
|
||||
|
||||
|
||||
def compute_log_pz(pz, z, args):
|
||||
inverse_bin_width = 2.**args.n_bits
|
||||
|
||||
if args.variable_type == 'discrete':
|
||||
if args.distribution_type == 'logistic':
|
||||
if args.n_mixtures == 1:
|
||||
log_pz = log_discretized_logistic(
|
||||
z, pz[0], pz[1], inverse_bin_width=inverse_bin_width)
|
||||
else:
|
||||
log_pz = log_mixture_discretized_logistic(
|
||||
z, pz[0], pz[1], pz[2],
|
||||
inverse_bin_width=inverse_bin_width)
|
||||
elif args.distribution_type == 'normal':
|
||||
log_pz = log_discretized_normal(
|
||||
z, *pz, inverse_bin_width=inverse_bin_width)
|
||||
|
||||
elif args.variable_type == 'continuous':
|
||||
if args.distribution_type == 'logistic':
|
||||
log_pz = log_logistic(z, *pz)
|
||||
elif args.distribution_type == 'normal':
|
||||
if args.n_mixtures == 1:
|
||||
log_pz = log_normal(z, *pz)
|
||||
else:
|
||||
log_pz = log_mixture_normal(z, *pz)
|
||||
elif args.distribution_type == 'steplogistic':
|
||||
z = _round_straightthrough(z * 256.) / 256.
|
||||
log_pz = log_discretized_logistic(z, *pz)
|
||||
|
||||
log_pz = torch.sum(
|
||||
log_pz,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
return log_pz
|
||||
|
||||
|
||||
def compute_loss_function(pz, z, pys, ys, ldj, args):
|
||||
"""
|
||||
Computes the cross entropy loss function while summing over batch dimension, not averaged!
|
||||
:param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
|
||||
:param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
|
||||
:param z_mu: mean of z_0
|
||||
:param z_var: variance of z_0
|
||||
:param z_0: first stochastic latent variable
|
||||
:param z_k: last stochastic latent variable
|
||||
:param ldj: log det jacobian
|
||||
:param args: global parameter settings
|
||||
:param beta: beta for kl loss
|
||||
:return: loss, ce, kl
|
||||
"""
|
||||
batch_size = z.size(0)
|
||||
|
||||
# Get array loss, sum over batch
|
||||
loss_array, bpd_array, bpd_per_prior_array = \
|
||||
compute_loss_array(pz, z, pys, ys, ldj, args)
|
||||
|
||||
loss = torch.mean(loss_array)
|
||||
bpd = torch.mean(bpd_array).item()
|
||||
bpd_per_prior = [torch.mean(x) for x in bpd_per_prior_array]
|
||||
|
||||
return loss, bpd, bpd_per_prior
|
||||
|
||||
|
||||
def convert_bpd(log_p, input_size):
|
||||
return -log_p / (np.prod(input_size) * np.log(2.))
|
||||
|
||||
|
||||
def compute_loss_array(pz, z, pys, ys, ldj, args):
|
||||
"""
|
||||
Computes the cross entropy loss function while summing over batch dimension, not averaged!
|
||||
:param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
|
||||
:param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
|
||||
:param z_mu: mean of z_0
|
||||
:param z_var: variance of z_0
|
||||
:param z_0: first stochastic latent variable
|
||||
:param z_k: last stochastic latent variable
|
||||
:param ldj: log det jacobian
|
||||
:param args: global parameter settings
|
||||
:param beta: beta for kl loss
|
||||
:return: loss, ce, kl
|
||||
"""
|
||||
bpd_per_prior = []
|
||||
|
||||
# Likelihood of final representation.
|
||||
log_pz = compute_log_pz(pz, z, args)
|
||||
|
||||
bpd_per_prior.append(convert_bpd(log_pz.detach(), args.input_size))
|
||||
|
||||
log_p = log_pz
|
||||
|
||||
# Add likelihoods of intermediate representations.
|
||||
if ys:
|
||||
log_pys = compute_log_ps(pys, ys, args)
|
||||
|
||||
for log_py in log_pys:
|
||||
log_p += log_py
|
||||
|
||||
bpd_per_prior.append(convert_bpd(log_py.detach(), args.input_size))
|
||||
|
||||
log_p += ldj
|
||||
|
||||
loss = -log_p
|
||||
bpd = convert_bpd(log_p.detach(), args.input_size)
|
||||
|
||||
return loss, bpd, bpd_per_prior
|
||||
|
||||
|
||||
def calculate_loss(pz, z, pys, ys, ldj, loss_aux, args):
|
||||
return compute_loss_function(pz, z, pys, ys, ldj, loss_aux, args)
|
||||
|
|
@ -1,174 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import torch
|
||||
|
||||
from optimization.loss import calculate_loss
|
||||
from utils.visual_evaluation import plot_reconstructions
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def train(epoch, train_loader, model, opt, args):
|
||||
model.train()
|
||||
train_loss = np.zeros(len(train_loader))
|
||||
train_bpd = np.zeros(len(train_loader))
|
||||
|
||||
num_data = 0
|
||||
|
||||
for batch_idx, (data, _) in enumerate(train_loader):
|
||||
data = data.view(-1, *args.input_size)
|
||||
|
||||
data = data.to(args.DEVICE)
|
||||
|
||||
opt.zero_grad()
|
||||
loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data)
|
||||
|
||||
loss = torch.mean(loss)
|
||||
bpd = torch.mean(bpd)
|
||||
bpd_per_prior = [torch.mean(i) for i in bpd_per_prior]
|
||||
|
||||
loss.backward()
|
||||
loss = loss.item()
|
||||
train_loss[batch_idx] = loss
|
||||
train_bpd[batch_idx] = bpd
|
||||
|
||||
ldj = torch.mean(ldj).item() / np.prod(args.input_size) / np.log(2)
|
||||
|
||||
opt.step()
|
||||
|
||||
num_data += len(data)
|
||||
|
||||
if batch_idx % args.log_interval == 0:
|
||||
perc = 100. * batch_idx / len(train_loader)
|
||||
|
||||
tmp = 'Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] \tLoss: {:11.6f}\tbpd: {:8.6f}\tbits ldj: {:8.6f}'
|
||||
print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, loss, bpd, ldj))
|
||||
|
||||
print('z min: {:8.3f}, max: {:8.3f}'.format(torch.min(z).item() * 256, torch.max(z).item() * 256))
|
||||
|
||||
print('z bpd: {:.3f}'.format(bpd_per_prior[0]))
|
||||
for i in range(1, len(bpd_per_prior)):
|
||||
print('y{} bpd: {:.3f}'.format(i-1, bpd_per_prior[i]))
|
||||
|
||||
print('pz mu', np.mean(pz[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
print('pz logs ', np.mean(pz[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
if len(pz) == 3:
|
||||
print('pz pi ', np.mean(pz[2].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
|
||||
for i, py in enumerate(pys):
|
||||
print('py{} mu '.format(i), np.mean(py[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
print('py{} logs '.format(i), np.mean(py[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
|
||||
from utils.visual_evaluation import plot_images
|
||||
import os
|
||||
if not os.path.exists(args.snap_dir + 'training/'):
|
||||
os.makedirs(args.snap_dir + 'training/')
|
||||
|
||||
print('====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'.format(
|
||||
epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader)))
|
||||
|
||||
return train_loss, train_bpd
|
||||
|
||||
|
||||
def evaluate(train_loader, val_loader, model, model_sample, args, testing=False, file=None, epoch=0):
|
||||
model.eval()
|
||||
|
||||
loss_type = 'bpd'
|
||||
|
||||
def analyse(data_loader, plot=False):
|
||||
bpds = []
|
||||
batch_idx = 0
|
||||
with torch.no_grad():
|
||||
for data, _ in data_loader:
|
||||
batch_idx += 1
|
||||
|
||||
if args.cuda:
|
||||
data = data.cuda()
|
||||
|
||||
data = data.view(-1, *args.input_size)
|
||||
|
||||
loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj = \
|
||||
model(data)
|
||||
loss = torch.mean(loss).item()
|
||||
batch_bpd = torch.mean(batch_bpd).item()
|
||||
|
||||
bpds.append(batch_bpd)
|
||||
|
||||
bpd = np.mean(bpds)
|
||||
|
||||
with torch.no_grad():
|
||||
if not testing and plot:
|
||||
x_sample = model_sample.sample(n=100)
|
||||
|
||||
try:
|
||||
plot_reconstructions(
|
||||
x_sample, bpd, loss_type, epoch, args)
|
||||
except:
|
||||
print('Not plotting')
|
||||
|
||||
return bpd
|
||||
|
||||
bpd_train = analyse(train_loader)
|
||||
bpd_val = analyse(val_loader, plot=True)
|
||||
|
||||
with open(file, 'a') as ff:
|
||||
msg = 'epoch {}\ttrain bpd {:.3f}\tval bpd {:.3f}\t'.format(
|
||||
epoch,
|
||||
bpd_train,
|
||||
bpd_val)
|
||||
print(msg, file=ff)
|
||||
|
||||
loss = bpd_val * np.prod(args.input_size) * np.log(2.)
|
||||
bpd = bpd_val
|
||||
|
||||
file = None
|
||||
|
||||
# Compute log-likelihood
|
||||
with torch.no_grad():
|
||||
if testing:
|
||||
test_data = val_loader.dataset.data_tensor
|
||||
|
||||
if args.cuda:
|
||||
test_data = test_data.cuda()
|
||||
|
||||
print('Computing log-likelihood on test set')
|
||||
|
||||
model.eval()
|
||||
|
||||
log_likelihood = analyse(test_data)
|
||||
|
||||
else:
|
||||
log_likelihood = None
|
||||
nll_bpd = None
|
||||
|
||||
if file is None:
|
||||
if testing:
|
||||
print('====> Test set loss: {:.4f}'.format(loss))
|
||||
print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood))
|
||||
|
||||
print('====> Test set bpd (elbo): {:.4f}'.format(bpd))
|
||||
print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood/
|
||||
(np.prod(args.input_size) * np.log(2.))))
|
||||
|
||||
else:
|
||||
print('====> Validation set loss: {:.4f}'.format(loss))
|
||||
print('====> Validation set bpd: {:.4f}'.format(bpd))
|
||||
else:
|
||||
with open(file, 'a') as ff:
|
||||
if testing:
|
||||
print('====> Test set loss: {:.4f}'.format(loss), file=ff)
|
||||
print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood), file=ff)
|
||||
|
||||
print('====> Test set bpd: {:.4f}'.format(bpd), file=ff)
|
||||
print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood /
|
||||
(np.prod(args.input_size) * np.log(2.))),
|
||||
file=ff)
|
||||
|
||||
else:
|
||||
print('====> Validation set loss: {:.4f}'.format(loss), file=ff)
|
||||
print('====> Validation set bpd: {:.4f}'.format(loss / (np.prod(args.input_size) * np.log(2.))),
|
||||
file=ff)
|
||||
|
||||
if not testing:
|
||||
return loss, bpd
|
||||
else:
|
||||
return log_likelihood, nll_bpd
|
||||
|
|
@ -1,209 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
MIN_EPSILON = 1e-5
|
||||
MAX_EPSILON = 1.-1e-5
|
||||
|
||||
|
||||
PI = math.pi
|
||||
|
||||
|
||||
def log_min_exp(a, b, epsilon=1e-8):
|
||||
"""
|
||||
Computes the log of exp(a) - exp(b) in a (more) numerically stable fashion.
|
||||
Using:
|
||||
log(exp(a) - exp(b))
|
||||
c + log(exp(a-c) - exp(b-c))
|
||||
a + log(1 - exp(b-a))
|
||||
And note that we assume b < a always.
|
||||
"""
|
||||
y = a + torch.log(1 - torch.exp(b - a) + epsilon)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def log_normal(x, mean, logvar):
|
||||
logp = -0.5 * logvar
|
||||
logp += -0.5 * np.log(2 * PI)
|
||||
logp += -0.5 * (x - mean) * (x - mean) / torch.exp(logvar)
|
||||
return logp
|
||||
|
||||
|
||||
def log_mixture_normal(x, mean, logvar, pi):
|
||||
x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
|
||||
|
||||
logp_mixtures = log_normal(x, mean, logvar)
|
||||
|
||||
logp = torch.log(torch.sum(pi * torch.exp(logp_mixtures), dim=-1) + 1e-8)
|
||||
|
||||
return logp
|
||||
|
||||
|
||||
def sample_normal(mean, logvar):
|
||||
y = torch.randn_like(mean)
|
||||
|
||||
x = torch.exp(0.5 * logvar) * y + mean
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def sample_mixture_normal(mean, logvar, pi):
|
||||
b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
|
||||
pi = pi.view(b * c * h * w, n_mixtures)
|
||||
sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)
|
||||
|
||||
# Select mixture params
|
||||
mean = mean.view(b * c * h * w, n_mixtures)
|
||||
mean = mean[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w)
|
||||
logvar = logvar.view(b * c * h * w, n_mixtures)
|
||||
logvar = logvar[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w)
|
||||
|
||||
y = sample_normal(mean, logvar)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def log_logistic(x, mean, logscale):
|
||||
"""
|
||||
pdf = sigma([x - mean] / scale) * [1 - sigma(...)] * 1/scale
|
||||
"""
|
||||
scale = torch.exp(logscale)
|
||||
|
||||
u = (x - mean) / scale
|
||||
|
||||
logp = F.logsigmoid(u) + F.logsigmoid(-u) - logscale
|
||||
|
||||
return logp
|
||||
|
||||
|
||||
def sample_logistic(mean, logscale):
|
||||
y = torch.rand_like(mean)
|
||||
|
||||
x = torch.exp(logscale) * torch.log(y / (1 - y)) + mean
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def log_discretized_logistic(x, mean, logscale, inverse_bin_width):
|
||||
scale = torch.exp(logscale)
|
||||
|
||||
logp = log_min_exp(
|
||||
F.logsigmoid((x + 0.5 / inverse_bin_width - mean) / scale),
|
||||
F.logsigmoid((x - 0.5 / inverse_bin_width - mean) / scale))
|
||||
|
||||
return logp
|
||||
|
||||
|
||||
def discretized_logistic_cdf(x, mean, logscale, inverse_bin_width):
|
||||
scale = torch.exp(logscale)
|
||||
|
||||
cdf = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)
|
||||
|
||||
return cdf
|
||||
|
||||
|
||||
def sample_discretized_logistic(mean, logscale, inverse_bin_width):
|
||||
x = sample_logistic(mean, logscale)
|
||||
|
||||
x = torch.round(x * inverse_bin_width) / inverse_bin_width
|
||||
return x
|
||||
|
||||
|
||||
def normal_cdf(value, loc, std):
|
||||
return 0.5 * (1 + torch.erf((value - loc) * std.reciprocal() / math.sqrt(2)))
|
||||
|
||||
|
||||
def log_discretized_normal(x, mean, logvar, inverse_bin_width):
|
||||
std = torch.exp(0.5 * logvar)
|
||||
log_p = torch.log(normal_cdf(x + 0.5 / inverse_bin_width, mean, std) - normal_cdf(x - 0.5 / inverse_bin_width, mean, std) + 1e-7)
|
||||
|
||||
return log_p
|
||||
|
||||
|
||||
def log_mixture_discretized_normal(x, mean, logvar, pi, inverse_bin_width):
|
||||
std = torch.exp(0.5 * logvar)
|
||||
|
||||
x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
|
||||
|
||||
p = normal_cdf(x + 0.5 / inverse_bin_width, mean, std) - normal_cdf(x - 0.5 / inverse_bin_width, mean, std)
|
||||
|
||||
p = torch.sum(p * pi, dim=-1)
|
||||
|
||||
logp = torch.log(p + 1e-8)
|
||||
|
||||
return logp
|
||||
|
||||
|
||||
def sample_discretized_normal(mean, logvar, inverse_bin_width):
|
||||
y = torch.randn_like(mean)
|
||||
|
||||
x = torch.exp(0.5 * logvar) * y + mean
|
||||
|
||||
x = torch.round(x * inverse_bin_width) / inverse_bin_width
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def log_mixture_discretized_logistic(x, mean, logscale, pi, inverse_bin_width):
|
||||
scale = torch.exp(logscale)
|
||||
|
||||
x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
|
||||
|
||||
p = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale) \
|
||||
- torch.sigmoid((x - 0.5 / inverse_bin_width - mean) / scale)
|
||||
|
||||
p = torch.sum(p * pi, dim=-1)
|
||||
|
||||
logp = torch.log(p + 1e-8)
|
||||
|
||||
return logp
|
||||
|
||||
|
||||
def mixture_discretized_logistic_cdf(x, mean, logscale, pi, inverse_bin_width):
|
||||
scale = torch.exp(logscale)
|
||||
|
||||
x = x[..., None]
|
||||
|
||||
cdfs = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)
|
||||
|
||||
cdf = torch.sum(cdfs * pi, dim=-1)
|
||||
|
||||
return cdf
|
||||
|
||||
|
||||
def sample_mixture_discretized_logistic(mean, logs, pi, inverse_bin_width):
|
||||
# Sample mixtures
|
||||
b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
|
||||
pi = pi.view(b * c * h * w, n_mixtures)
|
||||
sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)
|
||||
|
||||
# Select mixture params
|
||||
mean = mean.view(b * c * h * w, n_mixtures)
|
||||
mean = mean[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w)
|
||||
logs = logs.view(b * c * h * w, n_mixtures)
|
||||
logs = logs[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w)
|
||||
|
||||
y = torch.rand_like(mean)
|
||||
x = torch.exp(logs) * torch.log(y / (1 - y)) + mean
|
||||
|
||||
x = torch.round(x * inverse_bin_width) / inverse_bin_width
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def log_multinomial(logits, targets):
|
||||
return -F.cross_entropy(logits, targets, reduction='none')
|
||||
|
||||
|
||||
def sample_multinomial(logits):
|
||||
b, n_categories, c, h, w = logits.size()
|
||||
logits = logits.permute(0, 2, 3, 4, 1)
|
||||
p = F.softmax(logits, dim=-1)
|
||||
p = p.view(b * c * h * w, n_categories)
|
||||
x = torch.multinomial(p, num_samples=1).view(b, c, h, w)
|
||||
return x
|
||||
|
|
@ -1,264 +0,0 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import numbers
|
||||
|
||||
import torch
|
||||
import torch.utils.data as data_utils
|
||||
import pickle
|
||||
from scipy.io import loadmat
|
||||
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
from torch.utils.data import Dataset
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional as vf
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import os
|
||||
import os.path
|
||||
from os.path import join
|
||||
import sys
|
||||
import tarfile
|
||||
|
||||
|
||||
class ToTensorNoNorm():
|
||||
def __call__(self, X_i):
|
||||
return torch.from_numpy(np.array(X_i, copy=False)).permute(2, 0, 1)
|
||||
|
||||
|
||||
class PadToMultiple(object):
|
||||
def __init__(self, multiple, fill=0, padding_mode='constant'):
|
||||
assert isinstance(multiple, numbers.Number)
|
||||
assert isinstance(fill, (numbers.Number, str, tuple))
|
||||
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
|
||||
|
||||
self.multiple = multiple
|
||||
self.fill = fill
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image): Image to be padded.
|
||||
Returns:
|
||||
PIL Image: Padded image.
|
||||
"""
|
||||
w, h = img.size
|
||||
m = self.multiple
|
||||
nw = (w // m + int((w % m) != 0)) * m
|
||||
nh = (h // m + int((h % m) != 0)) * m
|
||||
padw = nw - w
|
||||
padh = nh - h
|
||||
|
||||
out = vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode)
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\
|
||||
format(self.mulitple, self.fill, self.padding_mode)
|
||||
|
||||
|
||||
class CustomTensorDataset(Dataset):
|
||||
"""Dataset wrapping tensors.
|
||||
|
||||
Each sample will be retrieved by indexing tensors along the first dimension.
|
||||
|
||||
Arguments:
|
||||
*tensors (Tensor): tensors that have the same size of the first dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, *tensors, transform=None):
|
||||
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
|
||||
self.tensors = tensors
|
||||
self.transform = transform
|
||||
|
||||
def __getitem__(self, index):
|
||||
from PIL import Image
|
||||
|
||||
X, y = self.tensors
|
||||
X_i, y_i, = X[index], y[index]
|
||||
|
||||
if self.transform:
|
||||
X_i = self.transform(X_i)
|
||||
X_i = torch.from_numpy(np.array(X_i, copy=False))
|
||||
X_i = X_i.permute(2, 0, 1)
|
||||
|
||||
return X_i, y_i
|
||||
|
||||
def __len__(self):
|
||||
return self.tensors[0].size(0)
|
||||
|
||||
|
||||
def load_cifar10(args, **kwargs):
|
||||
# set args
|
||||
args.input_size = [3, 32, 32]
|
||||
args.input_type = 'continuous'
|
||||
args.dynamic_binarization = False
|
||||
|
||||
from keras.datasets import cifar10
|
||||
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
|
||||
|
||||
x_train = x_train.transpose(0, 3, 1, 2)
|
||||
x_test = x_test.transpose(0, 3, 1, 2)
|
||||
|
||||
import math
|
||||
|
||||
if args.data_augmentation_level == 2:
|
||||
data_transform = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'),
|
||||
transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
|
||||
transforms.CenterCrop(32)
|
||||
])
|
||||
elif args.data_augmentation_level == 1:
|
||||
data_transform = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
])
|
||||
else:
|
||||
data_transform = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
])
|
||||
|
||||
x_val = x_train[-10000:]
|
||||
y_val = y_train[-10000:]
|
||||
|
||||
x_train = x_train[:-10000]
|
||||
y_train = y_train[:-10000]
|
||||
|
||||
train = CustomTensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train), transform=data_transform)
|
||||
train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
|
||||
|
||||
validation = data_utils.TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val))
|
||||
val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs)
|
||||
|
||||
test = data_utils.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test))
|
||||
test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs)
|
||||
|
||||
return train_loader, val_loader, test_loader, args
|
||||
|
||||
|
||||
def extract_tar(tarpath):
|
||||
assert tarpath.endswith('.tar')
|
||||
|
||||
startdir = tarpath[:-4] + '/'
|
||||
|
||||
if os.path.exists(startdir):
|
||||
return startdir
|
||||
|
||||
print('Extracting', tarpath)
|
||||
|
||||
with tarfile.open(name=tarpath) as tar:
|
||||
t = 0
|
||||
done = False
|
||||
while not done:
|
||||
path = join(startdir, 'images{}'.format(t))
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
print(path)
|
||||
|
||||
for i in range(50000):
|
||||
member = tar.next()
|
||||
|
||||
if member is None:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Skip directories
|
||||
while member.isdir():
|
||||
member = tar.next()
|
||||
if member is None:
|
||||
done = True
|
||||
break
|
||||
|
||||
member.name = member.name.split('/')[-1]
|
||||
|
||||
tar.extract(member, path=path)
|
||||
|
||||
t += 1
|
||||
|
||||
return startdir
|
||||
|
||||
|
||||
def load_imagenet(resolution, args, **kwargs):
|
||||
assert resolution == 32 or resolution == 64
|
||||
|
||||
args.input_size = [3, resolution, resolution]
|
||||
|
||||
trainpath = '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution)
|
||||
valpath = '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution)
|
||||
|
||||
trainpath = extract_tar(trainpath)
|
||||
valpath = extract_tar(valpath)
|
||||
|
||||
data_transform = transforms.Compose([
|
||||
ToTensorNoNorm()
|
||||
])
|
||||
|
||||
print('Starting loading ImageNet')
|
||||
|
||||
imagenet_data = torchvision.datasets.ImageFolder(
|
||||
trainpath,
|
||||
transform=data_transform)
|
||||
|
||||
print('Number of data images', len(imagenet_data))
|
||||
|
||||
val_idcs = np.random.choice(len(imagenet_data), size=20000, replace=False)
|
||||
train_idcs = np.setdiff1d(np.arange(len(imagenet_data)), val_idcs)
|
||||
|
||||
train_dataset = torch.utils.data.dataset.Subset(
|
||||
imagenet_data, train_idcs)
|
||||
val_dataset = torch.utils.data.dataset.Subset(
|
||||
imagenet_data, val_idcs)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
**kwargs)
|
||||
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
**kwargs)
|
||||
|
||||
test_dataset = torchvision.datasets.ImageFolder(
|
||||
valpath,
|
||||
transform=data_transform)
|
||||
|
||||
print('Number of val images:', len(test_dataset))
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
**kwargs)
|
||||
|
||||
return train_loader, val_loader, test_loader, args
|
||||
|
||||
|
||||
def load_dataset(args, **kwargs):
|
||||
|
||||
if args.dataset == 'cifar10':
|
||||
train_loader, val_loader, test_loader, args = load_cifar10(args, **kwargs)
|
||||
elif args.dataset == 'imagenet32':
|
||||
train_loader, val_loader, test_loader, args = load_imagenet(32, args, **kwargs)
|
||||
elif args.dataset == 'imagenet64':
|
||||
train_loader, val_loader, test_loader, args = load_imagenet(64, args, **kwargs)
|
||||
else:
|
||||
raise Exception('Wrong name of the dataset!')
|
||||
|
||||
return train_loader, val_loader, test_loader, args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
class Args():
|
||||
def __init__(self):
|
||||
self.batch_size = 128
|
||||
train_loader, val_loader, test_loader, args = load_imagenet32(Args())
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import numpy as np
|
||||
from scipy.misc import logsumexp
|
||||
from optimization.loss import calculate_loss_array
|
||||
|
||||
|
||||
def calculate_likelihood(X, model, args, S=5000, MB=500):
|
||||
|
||||
# set auxiliary variables for number of training and test sets
|
||||
N_test = X.size(0)
|
||||
|
||||
X = X.view(-1, *args.input_size)
|
||||
|
||||
likelihood_test = []
|
||||
|
||||
if S <= MB:
|
||||
R = 1
|
||||
else:
|
||||
R = S // MB
|
||||
S = MB
|
||||
|
||||
for j in range(N_test):
|
||||
if j % 100 == 0:
|
||||
print('Progress: {:.2f}%'.format(j / (1. * N_test) * 100))
|
||||
|
||||
x_single = X[j].unsqueeze(0)
|
||||
|
||||
a = []
|
||||
for r in range(0, R):
|
||||
# Repeat it for all training points
|
||||
x = x_single.expand(S, *x_single.size()[1:]).contiguous()
|
||||
|
||||
x_mean, z_mu, z_var, ldj, z0, zk = model(x)
|
||||
|
||||
a_tmp = calculate_loss_array(x_mean, x, z_mu, z_var, z0, zk, ldj, args)
|
||||
|
||||
a.append(-a_tmp.cpu().data.numpy())
|
||||
|
||||
# calculate max
|
||||
a = np.asarray(a)
|
||||
a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
|
||||
likelihood_x = logsumexp(a)
|
||||
likelihood_test.append(likelihood_x - np.log(len(a)))
|
||||
|
||||
likelihood_test = np.array(likelihood_test)
|
||||
|
||||
nll = -np.mean(likelihood_test)
|
||||
|
||||
if args.input_type == 'multinomial':
|
||||
bpd = nll/(np.prod(args.input_size) * np.log(2.))
|
||||
elif args.input_type == 'binary':
|
||||
bpd = 0.
|
||||
else:
|
||||
raise ValueError('invalid input type!')
|
||||
|
||||
return nll, bpd
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
# noninteractive background
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def plot_training_curve(train_loss, validation_loss, fname='training_curve.pdf', labels=None):
|
||||
"""
|
||||
Plots train_loss and validation loss as a function of optimization iteration
|
||||
:param train_loss: np.array of train_loss (1D or 2D)
|
||||
:param validation_loss: np.array of validation loss (1D or 2D)
|
||||
:param fname: output file name
|
||||
:param labels: if train_loss and validation loss are 2D, then labels indicate which variable is varied
|
||||
accross training curves.
|
||||
:return: None
|
||||
"""
|
||||
|
||||
plt.close()
|
||||
|
||||
matplotlib.rcParams.update({'font.size': 14})
|
||||
matplotlib.rcParams['mathtext.fontset'] = 'stix'
|
||||
matplotlib.rcParams['font.family'] = 'STIXGeneral'
|
||||
|
||||
if len(train_loss.shape) == 1:
|
||||
# Single training curve
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1)
|
||||
figsize = (6, 4)
|
||||
|
||||
if train_loss.shape[0] == validation_loss.shape[0]:
|
||||
# validation score evaluated every iteration
|
||||
x = np.arange(train_loss.shape[0])
|
||||
ax.plot(x, train_loss, '-', lw=2., color='black', label='train')
|
||||
ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')
|
||||
|
||||
elif train_loss.shape[0] % validation_loss.shape[0] == 0:
|
||||
# validation score evaluated every epoch
|
||||
x = np.arange(train_loss.shape[0])
|
||||
ax.plot(x, train_loss, '-', lw=2., color='black', label='train')
|
||||
|
||||
x = np.arange(validation_loss.shape[0])
|
||||
x = (x + 1) * train_loss.shape[0] / validation_loss.shape[0]
|
||||
ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')
|
||||
else:
|
||||
raise ValueError('Length of train_loss and validation_loss must be equal or divisible')
|
||||
|
||||
miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
|
||||
maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
|
||||
ax.set_ylim([miny, maxy])
|
||||
|
||||
elif len(train_loss.shape) == 2:
|
||||
# Multiple training curves
|
||||
|
||||
cmap = plt.cm.brg
|
||||
|
||||
cNorm = matplotlib.colors.Normalize(vmin=0, vmax=train_loss.shape[0])
|
||||
scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)
|
||||
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1)
|
||||
figsize = (6, 4)
|
||||
|
||||
if labels is None:
|
||||
labels = ['%d' % i for i in range(train_loss.shape[0])]
|
||||
|
||||
if train_loss.shape[1] == validation_loss.shape[1]:
|
||||
for i in range(train_loss.shape[0]):
|
||||
color_val = scalarMap.to_rgba(i)
|
||||
|
||||
# validation score evaluated every iteration
|
||||
x = np.arange(train_loss.shape[0])
|
||||
ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i])
|
||||
ax.plot(x, validation_loss[i], '--', lw=2., color=color_val)
|
||||
|
||||
elif train_loss.shape[1] % validation_loss.shape[1] == 0:
|
||||
for i in range(train_loss.shape[0]):
|
||||
color_val = scalarMap.to_rgba(i)
|
||||
|
||||
# validation score evaluated every epoch
|
||||
x = np.arange(train_loss.shape[1])
|
||||
ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i])
|
||||
|
||||
x = np.arange(validation_loss.shape[1])
|
||||
x = (x+1) * train_loss.shape[1] / validation_loss.shape[1]
|
||||
ax.plot(x, validation_loss[i], '-', lw=2., color=color_val)
|
||||
|
||||
miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
|
||||
maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
|
||||
ax.set_ylim([miny, maxy])
|
||||
|
||||
else:
|
||||
raise ValueError('train_loss and validation_loss must be 1D or 2D arrays')
|
||||
|
||||
ax.set_xlabel('iteration')
|
||||
ax.set_ylabel('loss')
|
||||
plt.title('Training and validation loss')
|
||||
|
||||
fig.set_size_inches(figsize)
|
||||
fig.subplots_adjust(hspace=0.1)
|
||||
plt.savefig(fname, bbox_inches='tight')
|
||||
|
||||
plt.close()
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import imageio
|
||||
|
||||
|
||||
def plot_reconstructions(recon_mean, loss, loss_type, epoch, args):
|
||||
if epoch == 1:
|
||||
if not os.path.exists(args.snap_dir + 'reconstruction/'):
|
||||
os.makedirs(args.snap_dir + 'reconstruction/')
|
||||
if loss_type == 'bpd':
|
||||
fname = str(epoch) + '_bpd_%5.3f' % loss
|
||||
elif loss_type == 'elbo':
|
||||
fname = str(epoch) + '_elbo_%6.4f' % loss
|
||||
plot_images(args, recon_mean.data.cpu().numpy()[:100], args.snap_dir + 'reconstruction/', fname)
|
||||
|
||||
|
||||
def plot_images(args, x_sample, dir, file_name, size_x=10, size_y=10):
|
||||
batch, channels, height, width = x_sample.shape
|
||||
|
||||
print(x_sample.shape)
|
||||
|
||||
mosaic = np.zeros((height * size_y, width * size_x, channels))
|
||||
|
||||
for j in range(size_y):
|
||||
for i in range(size_x):
|
||||
idx = j * size_x + i
|
||||
|
||||
image = x_sample[idx]
|
||||
|
||||
mosaic[j*height:(j+1)*height, i*height:(i+1)*height] = \
|
||||
image.transpose(1, 2, 0)
|
||||
|
||||
# Remove channel for BW images
|
||||
mosaic = mosaic.squeeze()
|
||||
|
||||
imageio.imwrite(dir + file_name + '.png', mosaic)
|
||||
|
|
@ -89,7 +89,7 @@ def main():
|
|||
|
||||
model = None
|
||||
if args.model_path is not None:
|
||||
print("Loading the model...")
|
||||
print("Loading the models...")
|
||||
model = torch.load(args.model_path)
|
||||
|
||||
trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer()
|
||||
1
models/cnn/__init__.py
Normal file
1
models/cnn/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .cnn import CNNPredictor
|
||||
|
|
@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
|
|||
|
||||
from .trainer import Trainer
|
||||
from .train import train
|
||||
from utils import print_losses
|
||||
from ..utils import print_losses
|
||||
|
||||
class FullTrainer(Trainer):
|
||||
def execute(
|
||||
|
|
@ -7,7 +7,7 @@ from torch import nn as nn
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from .trainer import Trainer
|
||||
from model.cnn import CNNPredictor
|
||||
from ..models.cnn import CNNPredictor
|
||||
from .train import train
|
||||
|
||||
|
||||
|
|
@ -59,4 +59,4 @@ class OptunaTrainer(Trainer):
|
|||
best_model = CNNPredictor(
|
||||
**best_params
|
||||
)
|
||||
torch.save(best_model, "models/final_model.pt")
|
||||
torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")
|
||||
Reference in a new issue