commit ef4684ef39d21de791028874946f8f0501c0458c Author: Robin Meersman Date: Fri Nov 7 12:54:36 2025 +0100 feat: initial for IDF diff --git a/integer_discrete_flows/LICENSE b/integer_discrete_flows/LICENSE new file mode 100644 index 0000000..4784841 --- /dev/null +++ b/integer_discrete_flows/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2019 Emiel Hoogeboom, Jorn Peters, Rianne van den Berg, Max Welling + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/integer_discrete_flows/README.md b/integer_discrete_flows/README.md new file mode 100644 index 0000000..fa0a33e --- /dev/null +++ b/integer_discrete_flows/README.md @@ -0,0 +1,29 @@ +# Integer Discrete Flows and Lossless Compression + +This repository contains the code for the experiments presented in [1]. 
+ +## Usage + +### CIFAR10 setup: +``` +python main_experiment.py --n_flows 8 --n_levels 3 --n_channels 512 --coupling_type 'densenet' --densenet_depth 12 --n_mixtures 5 --splitprior_type 'densenet' +``` + + +### ImageNet32 setup: +``` +python main_experiment.py --evaluate_interval_epochs 5 --n_flows 8 --n_levels 3 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet32' --epochs 100 --lr_decay 0.99 +``` + + +### ImageNet64 setup: +``` +python main_experiment.py --evaluate_interval_epochs 1 --n_flows 8 --n_levels 4 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet64' --epochs 20 --lr_decay 0.99 --batch_size 64 +``` + +# Acknowledgements +The Robert Bosch GmbH is acknowledged for financial support. + +# References +[1] Hoogeboom, Emiel, Jorn WT Peters, Rianne van den Berg, and Max Welling. "Integer Discrete Flows and Lossless Compression." Conference on Neural Information Processing Systems (2019). + diff --git a/integer_discrete_flows/coding/__init__.py b/integer_discrete_flows/coding/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/integer_discrete_flows/coding/coder.py b/integer_discrete_flows/coding/coder.py new file mode 100644 index 0000000..04dd82c --- /dev/null +++ b/integer_discrete_flows/coding/coder.py @@ -0,0 +1,132 @@ +import numpy as np +from . 
import rans +from utils.distributions import discretized_logistic_cdf, \ + mixture_discretized_logistic_cdf +import torch + +precision = 24 +n_bins = 4096 + + +def cdf_fn(z, pz, variable_type, distribution_type, inverse_bin_width): + if variable_type == 'discrete': + if distribution_type == 'logistic': + if len(pz) == 2: + return discretized_logistic_cdf( + z, *pz, inverse_bin_width=inverse_bin_width) + elif len(pz) == 3: + return mixture_discretized_logistic_cdf( + z, *pz, inverse_bin_width=inverse_bin_width) + elif distribution_type == 'normal': + pass + + elif variable_type == 'continuous': + if distribution_type == 'logistic': + pass + elif distribution_type == 'normal': + pass + elif distribution_type == 'steplogistic': + pass + raise ValueError + + +def CDF_fn(pz, bin_width, variable_type, distribution_type): + mean = pz[0] if len(pz) == 2 else pz[0][..., (pz[0].size(-1) - 1) // 2] + MEAN = torch.round(mean / bin_width).long() + + bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None] + bin_locations = bin_locations.float() * bin_width + bin_locations = bin_locations.to(device=pz[0].device) + + pz = [param[:, :, :, :, None] for param in pz] + cdf = cdf_fn( + bin_locations - bin_width, + pz, + variable_type, + distribution_type, + 1./bin_width).cpu().numpy() + + # Compute CDFs, reweigh to give all bins at least + # 1 / (2^precision) probability. 
+ # CDF is equal to floor[cdf * (2^precision - n_bins)] + range(n_bins) + CDFs = (cdf * ((1 << precision) - n_bins)).astype('int') \ + + np.arange(n_bins) + + return CDFs, MEAN + + +def encode_sample( + z, pz, variable_type, distribution_type, bin_width=1./256, state=None): + if state is None: + state = rans.x_init + else: + state = rans.unflatten(state) + + CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type) + + # z is transformed to Z to match the indices for the CDFs array + Z = torch.round(z / bin_width).long() + n_bins // 2 - MEAN + Z = Z.cpu().numpy() + + if not ((np.sum(Z < 0) == 0 and np.sum(Z >= n_bins-1) == 0)): + print('Z out of allowed range of values, canceling compression') + return None + + Z, CDFs = Z.reshape(-1), CDFs.reshape(-1, n_bins).copy() + for symbol, cdf in zip(Z[::-1], CDFs[::-1]): + statfun = statfun_encode(cdf) + state = rans.append_symbol(statfun, precision)(state, symbol) + + state = rans.flatten(state) + + return state + + +def decode_sample( + state, pz, variable_type, distribution_type, bin_width=1./256): + state = rans.unflatten(state) + + device = pz[0].device + size = pz[0].size()[0:4] + + CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type) + + CDFs = CDFs.reshape(-1, n_bins) + result = np.zeros(len(CDFs), dtype=int) + for i, cdf in enumerate(CDFs): + statfun = statfun_decode(cdf) + state, symbol = rans.pop_symbol(statfun, precision)(state) + result[i] = symbol + + Z_flat = torch.from_numpy(result).to(device) + Z = Z_flat.view(size) - n_bins // 2 + MEAN + + z = Z.float() * bin_width + + state = rans.flatten(state) + + return state, z + + +def statfun_encode(CDF): + def _statfun_encode(symbol): + return CDF[symbol], CDF[symbol + 1] - CDF[symbol] + return _statfun_encode + + +def statfun_decode(CDF): + def _statfun_decode(cf): + # Search such that CDF[s] <= cf < CDF[s] + s = np.searchsorted(CDF, cf, side='right') + s = s - 1 + start, freq = statfun_encode(CDF)(s) + return s, (start, freq) + return 
_statfun_decode + + +def encode(x, symbol): + return rans.append_symbol(statfun_encode, precision)(x, symbol) + + +def decode(x): + return rans.pop_symbol(statfun_decode, precision)(x) diff --git a/integer_discrete_flows/coding/rans.py b/integer_discrete_flows/coding/rans.py new file mode 100644 index 0000000..02583a5 --- /dev/null +++ b/integer_discrete_flows/coding/rans.py @@ -0,0 +1,67 @@ +""" +Closely based on https://github.com/rygorous/ryg_rans/blob/master/rans64.h + +ROUGH GUIDE: +We use the pythonic names 'append' and 'pop' for encoding and decoding +respectively. The compressed state 'x' is an immutable stack, implemented using +a cons list. + +x: the current stack-like state of the encoder/decoder. + +precision: the natural numbers are divided into ranges of size 2^precision. + +start & freq: start indicates the beginning of the range in [0, 2^precision-1] +that the current symbol is represented by. freq is the length of the range. +freq is chosen such that p(symbol) ~= freq/2^precision. +""" +import numpy as np +from functools import reduce + + +rans_l = 1 << 31 # the lower bound of the normalisation interval +tail_bits = (1 << 32) - 1 + +x_init = (rans_l, ()) + +def append(x, start, freq, precision): + """Encodes a symbol with range [start, start + freq). 
All frequencies are + assumed to sum to "1 << precision", and the resulting bits get written to + x.""" + if x[0] >= ((rans_l >> precision) << 32) * freq: + x = (x[0] >> 32, (x[0] & tail_bits, x[1])) + return ((x[0] // freq) << precision) + (x[0] % freq) + start, x[1] + +def pop(x_, precision): + """Advances in the bit stream by "popping" a single symbol with range start + "start" and frequency "freq".""" + cf = x_[0] & ((1 << precision) - 1) + def pop(start, freq): + x = freq * (x_[0] >> precision) + cf - start, x_[1] + return ((x[0] << 32) | x[1][0], x[1][1]) if x[0] < rans_l else x + return cf, pop + +def append_symbol(statfun, precision): + def append_(x, symbol): + start, freq = statfun(symbol) + return append(x, start, freq, precision) + return append_ + +def pop_symbol(statfun, precision): + def pop_(x): + cf, pop_fun = pop(x, precision) + symbol, (start, freq) = statfun(cf) + return pop_fun(start, freq), symbol + return pop_ + +def flatten(x): + """Flatten a rans state x into a 1d numpy array.""" + out, x = [x[0] >> 32, x[0]], x[1] + while x: + x_head, x = x + out.append(x_head) + return np.asarray(out, dtype=np.uint32) + +def unflatten(arr): + """Unflatten a 1d numpy array into a rans state.""" + return (int(arr[0]) << 32 | int(arr[1]), + reduce(lambda tl, hd: (int(hd), tl), reversed(arr[2:]), ())) diff --git a/integer_discrete_flows/experiment_coding.py b/integer_discrete_flows/experiment_coding.py new file mode 100644 index 0000000..1d0cd75 --- /dev/null +++ b/integer_discrete_flows/experiment_coding.py @@ -0,0 +1,188 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import print_function +import argparse +import torch +import torch.utils.data +import numpy as np + +from utils.load_data import load_dataset + + +parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows') + +parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'], + metavar='DATASET', + 
help='Dataset choice.') + +parser.add_argument('-nc', '--no_cuda', action='store_true', default=False, + help='disables CUDA training') + +parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.') + +parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL', + help='how many batches to wait before logging training status') + +parser.add_argument('--evaluate_interval_epochs', type=int, default=25, + help='Evaluate per how many epochs') + + +# optimization settings +parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS', + help='number of epochs to train (default: 2000)') +parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING', + help='number of early stopping epochs') + +parser.add_argument('-bs', '--batch_size', type=int, default=10, metavar='BATCH_SIZE', + help='input batch size for training (default: 100)') +parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE', + help='learning rate') +parser.add_argument('--warmup', type=int, default=10, + help='number of warmup epochs') + +parser.add_argument('--data_augmentation_level', type=int, default=2, + help='data augmentation level') + +parser.add_argument('--no_decode', action='store_true', default=False, + help='disables decoding') + + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {} + + +def encode_images(img, model, decode): + batchsize, img_c, img_h, img_w = img.size() + c, h, w = model.args.input_size + + assert img_h == img_w and h == w + + if img_h != h: + assert img_h % h == 0 + steps = img_h // h + + states = [[] for i in range(batchsize)] + state_sizes = [0 for i in range(batchsize)] + bpd = [0 for i in range(batchsize)] + error = 0 + + for j in range(steps): + for i in range(steps): + r = encode_patches( + 
img[:, :, j*h:(j+1)*h, i*w:(i+1)*w], model, decode) + for b in range(batchsize): + + if r[0][b] is None: + states[b].append(None) + else: + states[b].extend(r[0][b]) + state_sizes[b] += r[1][b] + bpd[b] += r[2][b] / steps**2 + error += r[3] + return states, state_sizes, bpd, error + else: + return encode_patches(img, model, decode) + + +def encode_patches(imgs, model, decode): + batchsize, img_c, img_h, img_w = imgs.size() + c, h, w = model.args.input_size + assert img_h == h and img_w == w + + states = model.encode(imgs) + + bpd = model.forward(imgs)[1].cpu().numpy() + + state_sizes = [] + error = 0 + + for b in range(batchsize): + if states[b] is None: + # Using escape bit ;) + state_sizes += [8 * img_c * img_h * img_w + 1] + + # Error remains unchanged. + print('Escaping, not encoding.') + + else: + if decode: + x_recon = model.decode([states[b]]) + + error += torch.sum( + torch.abs(x_recon.int() - imgs[b].int())).item() + + # Append state plus an escape bit + state_sizes += [32 * len(states[b]) + 1] + + return states, state_sizes, bpd, error + + +def run(args, kwargs): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + args.snap_dir = snap_dir = \ + 'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/' + + # ================================================================================================================== + # SNAPSHOTS + # ================================================================================================================== + + # ================================================================================================================== + # LOAD DATA + # ================================================================================================================== + train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs) + + final_model = torch.load(snap_dir + 'a.model') + if hasattr(final_model, 'module'): + final_model = final_model.module + 
final_model = final_model.cuda() + + sizes = [] + errors = [] + bpds = [] + + import time + start = time.time() + + t = 0 + with torch.no_grad(): + for data, _ in test_loader: + if args.cuda: + data = data.cuda() + + state, state_sizes, bpd, error = \ + encode_images(data, final_model, decode=not args.no_decode) + + errors += [error] + bpds.extend(bpd) + sizes.extend(state_sizes) + + t += len(data) + + print( + 'Examples: {}/{} bpd compression: {:.3f} error: {},' + ' analytical bpd {:.3f}'.format( + t, len(test_loader.dataset), + np.mean(sizes) / np.prod(data.size()[1:]), + np.sum(errors), + np.mean(bpds) + )) + + if args.no_decode: + print('Not testing decoding.') + else: + print('Error: {}'.format(np.sum(errors))) + + print('Took {:.3f} seconds / example'.format((time.time() - start) / t)) + print('Final bpd: {:.3f} error: {}'.format( + np.mean(sizes) / np.prod(data.size()[1:]), + np.sum(errors))) + + +if __name__ == "__main__": + + run(args, kwargs) diff --git a/integer_discrete_flows/experiment_progressive_loading.py b/integer_discrete_flows/experiment_progressive_loading.py new file mode 100644 index 0000000..cbd01cd --- /dev/null +++ b/integer_discrete_flows/experiment_progressive_loading.py @@ -0,0 +1,105 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import print_function +import argparse +import time +import torch +import torch.utils.data +import torch.optim as optim +import numpy as np +import matplotlib.pyplot as plt +from torchvision.utils import make_grid + +import os +from optimization.training import train, evaluate +from utils.load_data import load_dataset +from utils.plotting import plot_training_curve +import imageio + + +parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows') + +parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'], + metavar='DATASET', + help='Dataset choice.') + +parser.add_argument('-bs', '--batch_size', type=int, 
default=256, metavar='BATCH_SIZE', + help='input batch size for training (default: 100)') + +parser.add_argument('--data_augmentation_level', type=int, default=2, + help='data augmentation level') + +parser.add_argument('-nc', '--no_cuda', action='store_true', default=False, + help='disables CUDA training') + + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {} + + +def run(args, kwargs): + + args.snap_dir = snap_dir = \ + 'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/' + + # ================================================================================================================== + # SNAPSHOTS + # ================================================================================================================== + + # ================================================================================================================== + # LOAD DATA + # ================================================================================================================== + train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs) + + final_model = torch.load(snap_dir + 'a.model') + if hasattr(final_model, 'module'): + final_model = final_model.module + + from models.backround import SmoothRound + for module in final_model.modules(): + if isinstance(module, SmoothRound): + module._round_decay = 1. 
+ + exp_dir = snap_dir + 'partials/' + os.makedirs(exp_dir, exist_ok=True) + + images = [] + with torch.no_grad(): + for data, _ in test_loader: + + if args.cuda: + data = data.cuda() + + for i in range(len(data)): + _, _, _, pz, z, pys, ys, ldj = final_model.forward(data[i:i+1]) + + for j in range(len(ys) + 1): + x_recon = final_model.inverse( + z, + ys[len(ys) - j:]) + + images.append(x_recon.float()) + + if i == 10: + break + break + + for j in range(len(ys) + 1): + + grid = make_grid( + torch.stack(images[j::len(ys) + 1], dim=0).squeeze(), + nrow=11, padding=0, + normalize=True, range=None, + scale_each=False, pad_value=0) + + imageio.imwrite( + exp_dir + 'loaded{j}.png'.format(j=j), + grid.cpu().numpy().transpose(1, 2, 0)) + + +if __name__ == "__main__": + + run(args, kwargs) diff --git a/integer_discrete_flows/main_experiment.py b/integer_discrete_flows/main_experiment.py new file mode 100644 index 0000000..330d280 --- /dev/null +++ b/integer_discrete_flows/main_experiment.py @@ -0,0 +1,285 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import print_function +import argparse +import time +import torch +import torch.utils.data +import torch.optim as optim +import numpy as np +import math +import random + +import os + +import datetime + +from optimization.training import train, evaluate +from utils.load_data import load_dataset + +parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows') + +parser.add_argument('-d', '--dataset', type=str, default='cifar10', + choices=['cifar10', 'imagenet32', 'imagenet64'], + metavar='DATASET', + help='Dataset choice.') + +parser.add_argument('-nc', '--no_cuda', action='store_true', default=False, + help='disables CUDA training') + +parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.') + +parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL', + help='how many batches to wait before logging 
training status') + +parser.add_argument('--evaluate_interval_epochs', type=int, default=25, + help='Evaluate per how many epochs') + +parser.add_argument('-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR', + help='output directory for model snapshots etc.') + +fp = parser.add_mutually_exclusive_group(required=False) +fp.add_argument('-te', '--testing', action='store_true', dest='testing', + help='evaluate on test set after training') +fp.add_argument('-va', '--validation', action='store_false', dest='testing', + help='only evaluate on validation set') +parser.set_defaults(testing=True) + +# optimization settings +parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS', + help='number of epochs to train (default: 2000)') +parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING', + help='number of early stopping epochs') + +parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE', + help='input batch size for training (default: 100)') +parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE', + help='learning rate') +parser.add_argument('--warmup', type=int, default=10, + help='number of warmup epochs') + +parser.add_argument('--data_augmentation_level', type=int, default=2, + help='data augmentation level') + +parser.add_argument('--variable_type', type=str, default='discrete', + help='variable type of data distribution: discrete/continuous', + choices=['discrete', 'continuous']) +parser.add_argument('--distribution_type', type=str, default='logistic', + choices=['logistic', 'normal', 'steplogistic'], + help='distribution type: logistic/normal') +parser.add_argument('--n_flows', type=int, default=8, + help='number of flows per level') +parser.add_argument('--n_levels', type=int, default=3, + help='number of levels') + +parser.add_argument('--n_bits', type=int, default=8, + help='') + +# ---------------- SETTINGS 
CONCERNING NETWORKS ------------- +parser.add_argument('--densenet_depth', type=int, default=8, + help='Depth of densenets') +parser.add_argument('--n_channels', type=int, default=512, + help='number of channels in coupling and splitprior') +# ---------------- ----------------------------- ------------- + + +# ---------------- SETTINGS CONCERNING COUPLING LAYERS ------------- +parser.add_argument('--coupling_type', type=str, default='shallow', + choices=['shallow', 'resnet', 'densenet'], + help='Type of coupling layer') +parser.add_argument('--splitfactor', default=0, type=int, + help='Split factor for coupling layers.') + +parser.add_argument('--split_quarter', dest='split_quarter', action='store_true', + help='Split coupling layer on quarter') +parser.add_argument('--no_split_quarter', dest='split_quarter', action='store_false') +parser.set_defaults(split_quarter=True) +# ---------------- ----------------------------------- ------------- + + +# ---------------- SETTINGS CONCERNING SPLITPRIORS ------------- +parser.add_argument('--splitprior_type', type=str, default='shallow', + choices=['none', 'shallow', 'resnet', 'densenet'], + help='Type of splitprior. Use \'none\' for no splitprior') +# ---------------- ------------------------------- ------------- + + +# ---------------- SETTINGS CONCERNING PRIORS ------------- +parser.add_argument('--n_mixtures', type=int, default=1, + help='number of mixtures') +# ---------------- ------------------------------- ------------- + +parser.add_argument('--hard_round', dest='hard_round', action='store_true', + help='Rounding of translation in discrete models. 
Weird ' + 'probabilistic implications, only for experimental phase') +parser.add_argument('--no_hard_round', dest='hard_round', action='store_false') +parser.set_defaults(hard_round=True) + +parser.add_argument('--round_approx', type=str, default='smooth', + choices=['smooth', 'stochastic']) + +parser.add_argument('--lr_decay', default=0.999, type=float, + help='Learning rate') + +parser.add_argument('--temperature', default=1.0, type=float, + help='Temperature used for BackRound. It is used in ' + 'the the SmoothRound module. ' + '(default=1.0') + +# gpu/cpu +parser.add_argument('--gpu_num', type=int, default=0, metavar='GPU', + help='choose GPU to run on.') + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +if args.manual_seed is None: + args.manual_seed = random.randint(1, 100000) +random.seed(args.manual_seed) +torch.manual_seed(args.manual_seed) +np.random.seed(args.manual_seed) + + +kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {} + + +def run(args, kwargs): + + print('\nMODEL SETTINGS: \n', args, '\n') + print("Random Seed: ", args.manual_seed) + + if 'imagenet' in args.dataset and args.evaluate_interval_epochs > 5: + args.evaluate_interval_epochs = 5 + + # ================================================================================================================== + # SNAPSHOTS + # ================================================================================================================== + args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_') + args.model_signature = args.model_signature.replace(':', '_') + + snapshots_path = os.path.join(args.out_dir, args.variable_type + '_' + args.distribution_type + args.dataset) + snap_dir = snapshots_path + + snap_dir += '_' + 'flows_' + str(args.n_flows) + '_levels_' + str(args.n_levels) + + snap_dir = snap_dir + '__' + args.model_signature + '/' + + args.snap_dir = snap_dir + + if not os.path.exists(snap_dir): + 
os.makedirs(snap_dir) + + with open(snap_dir + 'log.txt', 'a') as ff: + print('\nMODEL SETTINGS: \n', args, '\n', file=ff) + + # SAVING + torch.save(args, snap_dir + '.config') + + # ================================================================================================================== + # LOAD DATA + # ================================================================================================================== + train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs) + + # ================================================================================================================== + # SELECT MODEL + # ================================================================================================================== + # flow parameters and architecture choice are passed on to model through args + print(args.input_size) + + import models.Model as Model + + model = Model.Model(args) + args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model.set_temperature(args.temperature) + model.enable_hard_round(args.hard_round) + + model_sample = model + + # ==================================== + # INIT + # ==================================== + # data dependend initialization on CPU + for batch_idx, (data, _) in enumerate(train_loader): + model(data) + break + + if torch.cuda.device_count() > 1: + print("Let's use", torch.cuda.device_count(), "GPUs!") + model = torch.nn.DataParallel(model, dim=0) + + model.to(args.device) + + def lr_lambda(epoch): + return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch) + optimizer = optim.Adamax(model.parameters(), lr=args.learning_rate, eps=1.e-7) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1) + + # ================================================================================================================== + # TRAINING + # 
================================================================================================================== + train_bpd = [] + val_bpd = [] + + # for early stopping + best_val_bpd = np.inf + best_train_bpd = np.inf + epoch = 0 + + train_times = [] + + model.eval() + model.train() + + for epoch in range(1, args.epochs + 1): + t_start = time.time() + scheduler.step() + tr_loss, tr_bpd = train(epoch, train_loader, model, optimizer, args) + train_bpd.append(tr_bpd) + train_times.append(time.time()-t_start) + print('One training epoch took %.2f seconds' % (time.time()-t_start)) + + if epoch < 25 or epoch % args.evaluate_interval_epochs == 0: + v_loss, v_bpd = evaluate( + train_loader, val_loader, model, model_sample, args, + epoch=epoch, file=snap_dir + 'log.txt') + + val_bpd.append(v_bpd) + + # Model save based on TRAIN performance (is heavily correlated with validation performance.) + if np.mean(tr_bpd) < best_train_bpd: + best_train_bpd = np.mean(tr_bpd) + best_val_bpd = v_bpd + torch.save(model.module, snap_dir + 'a.model') + torch.save(optimizer, snap_dir + 'a.optimizer') + print('->model saved<-') + + print('(BEST: train bpd {:.4f}, test bpd {:.4f})\n'.format( + best_train_bpd, best_val_bpd)) + + if math.isnan(v_loss): + raise ValueError('NaN encountered!') + + train_bpd = np.hstack(train_bpd) + val_bpd = np.array(val_bpd) + + # training time per epoch + train_times = np.array(train_times) + mean_train_time = np.mean(train_times) + std_train_time = np.std(train_times, ddof=1) + print('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time)) + + # ================================================================================================================== + # EVALUATION + # ================================================================================================================== + final_model = torch.load(snap_dir + 'a.model') + test_loss, test_bpd = evaluate( + train_loader, test_loader, final_model, final_model, args, 
+ epoch=epoch, file=snap_dir + 'test_log.txt') + + print('Test loss / bpd: %.2f / %.2f' % (test_loss, test_bpd)) + + +if __name__ == "__main__": + + run(args, kwargs) diff --git a/integer_discrete_flows/media/IDF_poster.pdf b/integer_discrete_flows/media/IDF_poster.pdf new file mode 100644 index 0000000..a1c512d Binary files /dev/null and b/integer_discrete_flows/media/IDF_poster.pdf differ diff --git a/integer_discrete_flows/media/IDF_slides.pdf b/integer_discrete_flows/media/IDF_slides.pdf new file mode 100644 index 0000000..3fc439e Binary files /dev/null and b/integer_discrete_flows/media/IDF_slides.pdf differ diff --git a/integer_discrete_flows/models/Model.py b/integer_discrete_flows/models/Model.py new file mode 100644 index 0000000..67bb4ac --- /dev/null +++ b/integer_discrete_flows/models/Model.py @@ -0,0 +1,191 @@ +import torch +import models.generative_flows as generative_flows +import numpy as np +from models.utils import Base +from .priors import Prior +from optimization.loss import compute_loss_array +from coding.coder import encode_sample, decode_sample + + +class Normalize(Base): + def __init__(self, args): + super().__init__() + self.n_bits = args.n_bits + self.variable_type = args.variable_type + self.input_size = args.input_size + + def forward(self, x, ldj, reverse=False): + domain = 2.**self.n_bits + + if self.variable_type == 'discrete': + # Discrete variables will be measured on intervals sized 1/domain. + # Hence, there is no need to change the log Jacobian determinant. + dldj = 0 + elif self.variable_type == 'continuous': + dldj = -np.log(domain) * np.prod(self.input_size) + else: + raise ValueError + + if not reverse: + x = (x - domain / 2) / domain + ldj += dldj + else: + x = x * domain + domain / 2 + ldj -= dldj + + return x, ldj + + +class Model(Base): + """ + The base VAE class containing gated convolutional encoder and decoder + architecture. Can be used as a base class for VAE's with normalizing flows. 
+ """ + + def __init__(self, args): + super().__init__() + self.args = args + self.variable_type = args.variable_type + self.distribution_type = args.distribution_type + + n_channels, height, width = args.input_size + + self.normalize = Normalize(args) + + self.flow = generative_flows.GenerativeFlow( + n_channels, height, width, args) + + self.n_bits = args.n_bits + + self.z_size = self.flow.z_size + + self.prior = Prior(self.z_size, args) + + def dequantize(self, x): + if self.training: + x = x + torch.rand_like(x) + else: + # Required for stability. + alpha = 1e-3 + x = x + alpha + torch.rand_like(x) * (1 - 2 * alpha) + + return x + + def loss(self, pz, z, pys, ys, ldj): + batchsize = z.size(0) + loss, bpd, bpd_per_prior = \ + compute_loss_array(pz, z, pys, ys, ldj, self.args) + + for module in self.modules(): + if hasattr(module, 'auxillary_loss'): + loss += module.auxillary_loss() / batchsize + + return loss, bpd, bpd_per_prior + + def forward(self, x): + """ + Evaluates the model as a whole, encodes and decodes. Note that the log + det jacobian is zero for a plain VAE (without flows), and z_0 = z_k. + """ + # Decode z to x. 
+ + assert x.dtype == torch.uint8 + + x = x.float() + + ldj = torch.zeros_like(x[:, 0, 0, 0]) + if self.variable_type == 'continuous': + x = self.dequantize(x) + elif self.variable_type == 'discrete': + pass + else: + raise ValueError + + x, ldj = self.normalize(x, ldj) + + z, ldj, pys, ys = self.flow(x, ldj, pys=(), ys=()) + + pz, z, ldj = self.prior(z, ldj) + + loss, bpd, bpd_per_prior = self.loss(pz, z, pys, ys, ldj) + + return loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj + + def inverse(self, z, ys): + ldj = torch.zeros_like(z[:, 0, 0, 0]) + x, ldj, pys, py = \ + self.flow(z, ldj, pys=[], ys=ys, reverse=True) + + x, ldj = self.normalize(x, ldj, reverse=True) + + x_uint8 = torch.clamp(x, min=0, max=255).to( + torch.uint8) + + return x_uint8 + + def sample(self, n): + z_sample = self.prior.sample(n) + + ldj = torch.zeros_like(z_sample[:, 0, 0, 0]) + x_sample, ldj, pys, py = \ + self.flow(z_sample, ldj, pys=[], ys=[], reverse=True) + + x_sample, ldj = self.normalize(x_sample, ldj, reverse=True) + + x_sample_uint8 = torch.clamp(x_sample, min=0, max=255).to( + torch.uint8) + + return x_sample_uint8 + + def encode(self, x): + batchsize = x.size(0) + _, _, _, pz, z, pys, ys, _ = self.forward(x) + + pjs = list(pys) + [pz] + js = list(ys) + [z] + + states = [] + + for b in range(batchsize): + state = None + for pj, j in zip(pjs, js): + pj_b = [param[b:b+1] for param in pj] + j_b = j[b:b+1] + + state = encode_sample( + j_b, pj_b, self.variable_type, + self.distribution_type, state=state) + if state is None: + break + + states.append(state) + + return states + + def decode(self, states): + def decode_fn(states, pj): + states = list(states) + j = [] + + for b in range(len(states)): + pj_b = [param[b:b+1] for param in pj] + + states[b], j_b = decode_sample( + states[b], pj_b, self.variable_type, + self.distribution_type) + + j.append(j_b) + + j = torch.cat(j, dim=0) + return states, j + + states, z = self.prior.decode(states, decode_fn=decode_fn) + + ldj = 
torch.zeros_like(z[:, 0, 0, 0]) + + x, ldj = self.flow.decode(z, ldj, states, decode_fn=decode_fn) + + x, ldj = self.normalize(x, ldj, reverse=True) + x = x.to(dtype=torch.uint8) + + return x diff --git a/integer_discrete_flows/models/__init__.py b/integer_discrete_flows/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/integer_discrete_flows/models/backround.py b/integer_discrete_flows/models/backround.py new file mode 100644 index 0000000..c6296ba --- /dev/null +++ b/integer_discrete_flows/models/backround.py @@ -0,0 +1,151 @@ +import torch +import torch.nn.functional as F +import numpy as np + +from models.utils import Base + + +class RoundStraightThrough(torch.autograd.Function): + + def __init__(self): + super().__init__() + + @staticmethod + def forward(ctx, input): + rounded = torch.round(input, out=None) + return rounded + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + return grad_input + + +_round_straightthrough = RoundStraightThrough().apply + + +def _stacked_sigmoid(x, temperature, n_approx=3): + + x_ = x - 0.5 + rounded = torch.round(x_) + x_remainder = x_ - rounded + + size = x_.size() + x_remainder = x_remainder.view(size + (1,)) + + translation = torch.arange(n_approx) - n_approx // 2 + translation = translation.to(device=x.device, dtype=x.dtype) + translation = translation.view([1] * len(size) + [len(translation)]) + out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1) + + return out + rounded - (n_approx // 2) + + +class SmoothRound(Base): + def __init__(self): + self._temperature = None + self._n_approx = None + super().__init__() + self.hard_round = None + + @property + def temperature(self): + return self._temperature + + @temperature.setter + def temperature(self, value): + self._temperature = value + + if self._temperature <= 0.05: + self._n_approx = 1 + elif 0.05 < self._temperature < 0.13: + self._n_approx = 3 + else: + self._n_approx = 5 + + def 
forward(self, x): + assert self._temperature is not None + assert self._n_approx is not None + assert self.hard_round is not None + + if self.temperature <= 0.25: + h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx) + else: + h = x + + if self.hard_round: + h = _round_straightthrough(h) + + return h + + +class StochasticRound(Base): + def __init__(self): + super().__init__() + self.hard_round = None + + def forward(self, x): + u = torch.rand_like(x) + + h = x + u - 0.5 + + if self.hard_round: + h = _round_straightthrough(h) + + return h + + +class BackRound(Base): + + def __init__(self, args, inverse_bin_width): + """ + BackRound is an approximation to Round that allows for Backpropagation. + + Approximate the round function using a sum of translated sigmoids. + The temperature determines how well the round function is approximated, + i.e., a lower temperature corresponds to a better approximation, at + the cost of more vanishing gradients. + + BackRound supports the following settings: + * By setting hard to True and temperature > 0.25, BackRound + reduces to a round function with a straight through gradient + estimator + * When using 0 < temperature <= 0.25 and hard = True, the + output in the forward pass is equivalent to a round function, but the + gradient is approximated by the gradient of a sum of sigmoids. + * When using hard = False, the output is not constrained to integers. + * When temperature > 0.25 and hard = False, BackRound reduces to + the identity function. + + Arguments + --------- + temperature: float + Temperature used for stacked sigmoid approximated. If temperature + is greater than 0.25, the approximation reduces to the indentiy + function. + hard: bool + If hard is True, a (hard) round is applied before returning. The + gradient for this is approximated using the straight-through + estimator. 
+ """ + super().__init__() + self.inverse_bin_width = inverse_bin_width + self.round_approx = args.round_approx + + if args.round_approx == 'smooth': + self.round = SmoothRound() + elif args.round_approx == 'stochastic': + self.round = StochasticRound() + else: + raise ValueError + + def forward(self, x): + if self.round_approx == 'smooth' or self.round_approx == 'stochastic': + h = x * self.inverse_bin_width + + h = self.round(h) + + return h / self.inverse_bin_width + + else: + raise ValueError diff --git a/integer_discrete_flows/models/coupling.py b/integer_discrete_flows/models/coupling.py new file mode 100644 index 0000000..8645449 --- /dev/null +++ b/integer_discrete_flows/models/coupling.py @@ -0,0 +1,142 @@ +""" +Collection of flow strategies +""" + +from __future__ import print_function + +import torch +import numpy as np +from models.utils import Base +from .backround import BackRound +from .networks import NN + + +UNIT_TESTING = False + + +class SplitFactorCoupling(Base): + def __init__(self, c_in, factor, height, width, args): + super().__init__() + self.n_channels = args.n_channels + self.kernel = 3 + self.input_channel = c_in + self.round_approx = args.round_approx + + if args.variable_type == 'discrete': + self.round = BackRound( + args, inverse_bin_width=2**args.n_bits) + else: + self.round = None + + self.split_idx = c_in - (c_in // factor) + + self.nn = NN( + args=args, + c_in=self.split_idx, + c_out=c_in - self.split_idx, + height=height, + width=width, + kernel=self.kernel, + nn_type=args.coupling_type) + + def forward(self, z, ldj, reverse=False): + z1 = z[:, :self.split_idx, :, :] + z2 = z[:, self.split_idx:, :, :] + + t = self.nn(z1) + + if self.round is not None: + t = self.round(t) + + if not reverse: + z2 = z2 + t + else: + z2 = z2 - t + + z = torch.cat([z1, z2], dim=1) + + return z, ldj + + +class Coupling(Base): + def __init__(self, c_in, height, width, args): + super().__init__() + + if args.split_quarter: + factor = 4 + elif 
args.splitfactor > 1: + factor = args.splitfactor + else: + factor = 2 + + self.coupling = SplitFactorCoupling( + c_in, factor, height, width, args=args) + + def forward(self, z, ldj, reverse=False): + return self.coupling(z, ldj, reverse) + + +def test_generative_flow(): + import models.networks as networks + global UNIT_TESTING + + networks.UNIT_TESTING = True + UNIT_TESTING = True + + batch_size = 17 + + input_size = [12, 16, 16] + + class Args(): + def __init__(self): + self.input_size = input_size + self.learn_split = False + self.variable_type = 'continuous' + self.distribution_type = 'logistic' + self.round_approx = 'smooth' + self.coupling_type = 'shallow' + self.conv_type = 'standard' + self.densenet_depth = 8 + self.bottleneck = False + self.n_channels = 512 + self.network1x1 = 'standard' + self.auxilary_freq = -1 + self.actnorm = False + self.LU = False + self.coupling_lifting_L = True + self.splitprior = True + self.split_quarter = True + self.n_levels = 2 + self.n_flows = 2 + self.cond_L = True + self.n_bits = True + + args = Args() + + x = (torch.randint(256, size=[batch_size] + input_size).float() - 128.) / 256. + ldj = torch.zeros_like(x[:, 0, 0, 0]) + + model = Coupling(c_in=12, height=16, width=16, args=args) + + print(model) + + model.set_temperature(1.) 
+ model.enable_hard_round() + + model.eval() + + z, ldj = model(x, ldj, reverse=False) + + # Check if gradient computation works + loss = torch.sum(z**2) + loss.backward() + + recon, ldj = model(z, ldj, reverse=True) + + sse = torch.sum(torch.pow(x - recon, 2)).item() + ae = torch.abs(x - recon).sum() + print('Error in recon: sse {} ae {}'.format(sse / np.prod(input_size), ae)) + + +if __name__ == '__main__': + test_generative_flow() diff --git a/integer_discrete_flows/models/generative_flows.py b/integer_discrete_flows/models/generative_flows.py new file mode 100644 index 0000000..8b7bcca --- /dev/null +++ b/integer_discrete_flows/models/generative_flows.py @@ -0,0 +1,175 @@ +""" +Collection of flow strategies +""" + +from __future__ import print_function + +import torch +import numpy as np +from models.utils import Base +from .priors import SplitPrior +from .coupling import Coupling + + +UNIT_TESTING = False + + +def space_to_depth(x): + xs = x.size() + # Pick off every second element + x = x.view(xs[0], xs[1], xs[2] // 2, 2, xs[3] // 2, 2) + # Transpose picked elements next to channels. + x = x.permute((0, 1, 3, 5, 2, 4)).contiguous() + # Combine with channels. + x = x.view(xs[0], xs[1] * 4, xs[2] // 2, xs[3] // 2) + return x + + +def depth_to_space(x): + xs = x.size() + # Pick off elements from channels + x = x.view(xs[0], xs[1] // 4, 2, 2, xs[2], xs[3]) + # Transpose picked elements next to HW dimensions. + x = x.permute((0, 1, 4, 2, 5, 3)).contiguous() + # Combine with HW dimensions. 
+ x = x.view(xs[0], xs[1] // 4, xs[2] * 2, xs[3] * 2) + return x + + +def int_shape(x): + return list(map(int, x.size())) + + +class Flatten(Base): + def forward(self, x): + return x.view(x.size(0), -1) + + +class Reshape(Base): + def __init__(self, shape): + super().__init__() + self.shape = shape + + def forward(self, x): + return x.view(x.size(0), *self.shape) + + +class Reverse(Base): + def __init__(self): + super().__init__() + + def forward(self, z, reverse=False): + flip_idx = torch.arange(z.size(1) - 1, -1, -1).long() + z = z[:, flip_idx, :, :] + return z + + +class Permute(Base): + def __init__(self, n_channels): + super().__init__() + + permutation = np.arange(n_channels, dtype='int') + np.random.shuffle(permutation) + + permutation_inv = np.zeros(n_channels, dtype='int') + permutation_inv[permutation] = np.arange(n_channels, dtype='int') + + self.permutation = torch.from_numpy(permutation) + self.permutation_inv = torch.from_numpy(permutation_inv) + + def forward(self, z, ldj, reverse=False): + if not reverse: + z = z[:, self.permutation, :, :] + else: + z = z[:, self.permutation_inv, :, :] + + return z, ldj + + def InversePermute(self): + inv_permute = Permute(len(self.permutation)) + inv_permute.permutation = self.permutation_inv + inv_permute.permutation_inv = self.permutation + return inv_permute + + +class Squeeze(Base): + def __init__(self): + super().__init__() + + def forward(self, z, ldj, reverse=False): + if not reverse: + z = space_to_depth(z) + else: + z = depth_to_space(z) + return z, ldj + + +class GenerativeFlow(Base): + def __init__(self, n_channels, height, width, args): + super().__init__() + layers = [] + layers.append(Squeeze()) + n_channels *= 4 + height //= 2 + width //= 2 + + for level in range(args.n_levels): + + for i in range(args.n_flows): + perm_layer = Permute(n_channels) + layers.append(perm_layer) + + layers.append( + Coupling(n_channels, height, width, args)) + + if level < args.n_levels - 1: + if args.splitprior_type != 
'none': + # Standard splitprior + factor_out = n_channels // 2 + layers.append(SplitPrior(n_channels, factor_out, height, width, args)) + n_channels = n_channels - factor_out + + layers.append(Squeeze()) + n_channels *= 4 + height //= 2 + width //= 2 + + self.layers = torch.nn.ModuleList(layers) + self.z_size = (n_channels, height, width) + + def forward(self, z, ldj, pys=(), ys=(), reverse=False): + if not reverse: + for l, layer in enumerate(self.layers): + if isinstance(layer, (SplitPrior)): + py, y, z, ldj = layer(z, ldj) + pys += (py,) + ys += (y,) + + else: + z, ldj = layer(z, ldj) + + else: + for l, layer in reversed(list(enumerate(self.layers))): + if isinstance(layer, (SplitPrior)): + if len(ys) > 0: + z, ldj = layer.inverse(z, ldj, y=ys[-1]) + # Pop last element + ys = ys[:-1] + else: + z, ldj = layer.inverse(z, ldj, y=None) + + else: + z, ldj = layer(z, ldj, reverse=True) + + return z, ldj, pys, ys + + def decode(self, z, ldj, state, decode_fn): + + for l, layer in reversed(list(enumerate(self.layers))): + if isinstance(layer, SplitPrior): + z, ldj, state = layer.decode(z, ldj, state, decode_fn) + + else: + z, ldj = layer(z, ldj, reverse=True) + + return z, ldj diff --git a/integer_discrete_flows/models/networks.py b/integer_discrete_flows/models/networks.py new file mode 100644 index 0000000..480f4df --- /dev/null +++ b/integer_discrete_flows/models/networks.py @@ -0,0 +1,154 @@ +""" +Collection of flow strategies +""" + +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.utils import Base + + +UNIT_TESTING = False + + +class Conv2dReLU(Base): + def __init__( + self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0, + bias=True): + super().__init__() + + self.nn = nn.Conv2d(n_inputs, n_outputs, kernel_size, padding=padding) + + def forward(self, x): + h = self.nn(x) + + y = F.relu(h) + + return y + + +class ResidualBlock(Base): + def __init__(self, n_channels, kernel, 
Conv2dAct): + super().__init__() + + self.nn = torch.nn.Sequential( + Conv2dAct(n_channels, n_channels, kernel, padding=1), + torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1), + ) + + def forward(self, x): + h = self.nn(x) + h = F.relu(h + x) + return h + + +class DenseLayer(Base): + def __init__(self, args, n_inputs, growth, Conv2dAct): + super().__init__() + + conv1x1 = Conv2dAct( + n_inputs, n_inputs, kernel_size=1, stride=1, + padding=0, bias=True) + + self.nn = torch.nn.Sequential( + conv1x1, + Conv2dAct( + n_inputs, growth, kernel_size=3, stride=1, + padding=1, bias=True), + ) + + def forward(self, x): + h = self.nn(x) + + h = torch.cat([x, h], dim=1) + return h + + +class DenseBlock(Base): + def __init__( + self, args, n_inputs, n_outputs, kernel, Conv2dAct): + super().__init__() + depth = args.densenet_depth + + future_growth = n_outputs - n_inputs + + layers = [] + + for d in range(depth): + growth = future_growth // (depth - d) + + layers.append(DenseLayer(args, n_inputs, growth, Conv2dAct)) + n_inputs += growth + future_growth -= growth + + self.nn = torch.nn.Sequential(*layers) + + def forward(self, x): + return self.nn(x) + + +class Identity(Base): + def __init__(self): + super.__init__() + + def forward(self, x): + return x + + +class NN(Base): + def __init__( + self, args, c_in, c_out, height, width, nn_type, kernel=3): + super().__init__() + + Conv2dAct = Conv2dReLU + n_channels = args.n_channels + + if nn_type == 'shallow': + + if args.network1x1 == 'standard': + conv1x1 = Conv2dAct( + n_channels, n_channels, kernel_size=1, stride=1, + padding=0, bias=False) + + layers = [ + Conv2dAct(c_in, n_channels, kernel, padding=1), + conv1x1] + + layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)] + + elif nn_type == 'resnet': + layers = [ + Conv2dAct(c_in, n_channels, kernel, padding=1), + ResidualBlock(n_channels, kernel, Conv2dAct), + ResidualBlock(n_channels, kernel, Conv2dAct)] + + layers += [ + torch.nn.Conv2d(n_channels, c_out, 
kernel, padding=1) + ] + + elif nn_type == 'densenet': + layers = [ + DenseBlock( + args=args, + n_inputs=c_in, + n_outputs=n_channels + c_in, + kernel=kernel, + Conv2dAct=Conv2dAct)] + + layers += [ + torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1) + ] + else: + raise ValueError + + self.nn = torch.nn.Sequential(*layers) + + # Set parameters of last conv-layer to zero. + if not UNIT_TESTING: + self.nn[-1].weight.data.zero_() + self.nn[-1].bias.data.zero_() + + def forward(self, x): + return self.nn(x) diff --git a/integer_discrete_flows/models/priors.py b/integer_discrete_flows/models/priors.py new file mode 100644 index 0000000..5ace8e9 --- /dev/null +++ b/integer_discrete_flows/models/priors.py @@ -0,0 +1,164 @@ +""" +Collection of flow strategies +""" + +from __future__ import print_function + +import torch +import torch.nn.functional as F +from torch.nn import Parameter +from utils.distributions import sample_discretized_logistic, \ + sample_mixture_discretized_logistic, sample_normal, sample_logistic, \ + sample_discretized_normal, sample_mixture_normal +from models.utils import Base +from .networks import NN + + +def sample_prior(px, variable_type, distribution_type, inverse_bin_width): + if variable_type == 'discrete': + if distribution_type == 'logistic': + if len(px) == 2: + return sample_discretized_logistic( + *px, inverse_bin_width=inverse_bin_width) + elif len(px) == 3: + return sample_mixture_discretized_logistic( + *px, inverse_bin_width=inverse_bin_width) + + elif distribution_type == 'normal': + return sample_discretized_normal( + *px, inverse_bin_width=inverse_bin_width) + + elif variable_type == 'continuous': + if distribution_type == 'logistic': + return sample_logistic(*px) + elif distribution_type == 'normal': + if len(px) == 2: + return sample_normal(*px) + elif len(px) == 3: + return sample_mixture_normal(*px) + elif distribution_type == 'steplogistic': + return sample_logistic(*px) + + raise ValueError + + +class Prior(Base): 
+ def __init__(self, size, args): + super().__init__() + c, h, w = size + + self.inverse_bin_width = 2**args.n_bits + self.variable_type = args.variable_type + self.distribution_type = args.distribution_type + self.n_mixtures = args.n_mixtures + + if self.n_mixtures == 1: + self.mu = Parameter(torch.Tensor(c, h, w)) + self.logs = Parameter(torch.Tensor(c, h, w)) + elif self.n_mixtures > 1: + self.mu = Parameter(torch.Tensor(c, h, w, self.n_mixtures)) + self.logs = Parameter(torch.Tensor(c, h, w, self.n_mixtures)) + self.pi_logit = Parameter(torch.Tensor(c, h, w, self.n_mixtures)) + + self.reset_parameters() + + def reset_parameters(self): + self.mu.data.zero_() + + if self.n_mixtures > 1: + self.pi_logit.data.zero_() + for i in range(self.n_mixtures): + self.mu.data[..., i] += i - (self.n_mixtures - 1) / 2. + + self.logs.data.zero_() + + def get_pz(self, n): + if self.n_mixtures == 1: + mu = self.mu.repeat(n, 1, 1, 1) + logs = self.logs.repeat(n, 1, 1, 1) # scaling scale + return mu, logs + + elif self.n_mixtures > 1: + pi = F.softmax(self.pi_logit, dim=-1) + mu = self.mu.repeat(n, 1, 1, 1, 1) + logs = self.logs.repeat(n, 1, 1, 1, 1) + pi = pi.repeat(n, 1, 1, 1, 1) + return mu, logs, pi + + def forward(self, z, ldj): + pz = self.get_pz(z.size(0)) + + return pz, z, ldj + + def sample(self, n): + pz = self.get_pz(n) + + z_sample = sample_prior(pz, self.variable_type, self.distribution_type, self.inverse_bin_width) + + return z_sample + + def decode(self, states, decode_fn): + pz = self.get_pz(n=len(states)) + + states, z = decode_fn(states, pz) + return states, z + + +class SplitPrior(Base): + def __init__(self, c_in, factor_out, height, width, args): + super().__init__() + + self.split_idx = c_in - factor_out + self.inverse_bin_width = 2**args.n_bits + self.variable_type = args.variable_type + self.distribution_type = args.distribution_type + self.input_channel = c_in + + self.nn = NN( + args=args, + c_in=c_in - factor_out, + c_out=factor_out * 2, + height=height, + 
width=width, + nn_type=args.splitprior_type) + + def get_py(self, z): + h = self.nn(z) + mu = h[:, ::2, :, :] + logs = h[:, 1::2, :, :] + + py = [mu, logs] + + return py + + def split(self, z): + z1 = z[:, :self.split_idx, :, :] + y = z[:, self.split_idx:, :, :] + return z1, y + + def combine(self, z, y): + result = torch.cat([z, y], dim=1) + + return result + + def forward(self, z, ldj): + z, y = self.split(z) + + py = self.get_py(z) + + return py, y, z, ldj + + def inverse(self, z, ldj, y): + # Sample if y is not given. + if y is None: + py = self.get_py(z) + y = sample_prior(py, self.variable_type, self.distribution_type, self.inverse_bin_width) + + z = self.combine(z, y) + + return z, ldj + + def decode(self, z, ldj, states, decode_fn): + py = self.get_py(z) + states, y = decode_fn(states, py) + return self.combine(z, y), ldj, states diff --git a/integer_discrete_flows/models/utils.py b/integer_discrete_flows/models/utils.py new file mode 100644 index 0000000..958f760 --- /dev/null +++ b/integer_discrete_flows/models/utils.py @@ -0,0 +1,36 @@ +import torch + + +class Base(torch.nn.Module): + """ + The base class for modules. That contains a disable round mode + """ + + def __init__(self): + super().__init__() + + def _set_child_attribute(self, attr, value): + r"""Sets the module in rounding mode. + + This has any effect only on certain modules if variable type is + discrete. 
+ + Returns: + Module: self + """ + if hasattr(self, attr): + setattr(self, attr, value) + + for module in self.modules(): + if hasattr(module, attr): + setattr(module, attr, value) + return self + + def set_temperature(self, value): + self._set_child_attribute("temperature", value) + + def enable_hard_round(self, mode=True): + self._set_child_attribute("hard_round", mode) + + def disable_hard_round(self, mode=True): + self.enable_hard_round(not mode) diff --git a/integer_discrete_flows/optimization/__init__.py b/integer_discrete_flows/optimization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/integer_discrete_flows/optimization/loss.py b/integer_discrete_flows/optimization/loss.py new file mode 100644 index 0000000..6cdd704 --- /dev/null +++ b/integer_discrete_flows/optimization/loss.py @@ -0,0 +1,148 @@ +from __future__ import print_function + +import numpy as np +import torch +from utils.distributions import log_discretized_logistic, \ + log_mixture_discretized_logistic, log_normal, log_discretized_normal, \ + log_logistic, log_mixture_normal +from models.backround import _round_straightthrough + + +def compute_log_ps(pxs, xs, args): + # Add likelihoods of intermediate representations. 
+ inverse_bin_width = 2.**args.n_bits + + log_pxs = [] + for px, x in zip(pxs, xs): + + if args.variable_type == 'discrete': + if args.distribution_type == 'logistic': + log_px = log_discretized_logistic( + x, *px, inverse_bin_width=inverse_bin_width) + elif args.distribution_type == 'normal': + log_px = log_discretized_normal( + x, *px, inverse_bin_width=inverse_bin_width) + elif args.variable_type == 'continuous': + if args.distribution_type == 'logistic': + log_px = log_logistic(x, *px) + elif args.distribution_type == 'normal': + log_px = log_normal(x, *px) + elif args.distribution_type == 'steplogistic': + x = _round_straightthrough(x * inverse_bin_width) / inverse_bin_width + log_px = log_discretized_logistic( + x, *px, inverse_bin_width=inverse_bin_width) + + log_pxs.append( + torch.sum(log_px, dim=[1, 2, 3])) + + return log_pxs + + +def compute_log_pz(pz, z, args): + inverse_bin_width = 2.**args.n_bits + + if args.variable_type == 'discrete': + if args.distribution_type == 'logistic': + if args.n_mixtures == 1: + log_pz = log_discretized_logistic( + z, pz[0], pz[1], inverse_bin_width=inverse_bin_width) + else: + log_pz = log_mixture_discretized_logistic( + z, pz[0], pz[1], pz[2], + inverse_bin_width=inverse_bin_width) + elif args.distribution_type == 'normal': + log_pz = log_discretized_normal( + z, *pz, inverse_bin_width=inverse_bin_width) + + elif args.variable_type == 'continuous': + if args.distribution_type == 'logistic': + log_pz = log_logistic(z, *pz) + elif args.distribution_type == 'normal': + if args.n_mixtures == 1: + log_pz = log_normal(z, *pz) + else: + log_pz = log_mixture_normal(z, *pz) + elif args.distribution_type == 'steplogistic': + z = _round_straightthrough(z * 256.) / 256. 
+ log_pz = log_discretized_logistic(z, *pz) + + log_pz = torch.sum( + log_pz, + dim=[1, 2, 3]) + + return log_pz + + +def compute_loss_function(pz, z, pys, ys, ldj, args): + """ + Computes the cross entropy loss function while summing over batch dimension, not averaged! + :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits + :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1]. + :param z_mu: mean of z_0 + :param z_var: variance of z_0 + :param z_0: first stochastic latent variable + :param z_k: last stochastic latent variable + :param ldj: log det jacobian + :param args: global parameter settings + :param beta: beta for kl loss + :return: loss, ce, kl + """ + batch_size = z.size(0) + + # Get array loss, sum over batch + loss_array, bpd_array, bpd_per_prior_array = \ + compute_loss_array(pz, z, pys, ys, ldj, args) + + loss = torch.mean(loss_array) + bpd = torch.mean(bpd_array).item() + bpd_per_prior = [torch.mean(x) for x in bpd_per_prior_array] + + return loss, bpd, bpd_per_prior + + +def convert_bpd(log_p, input_size): + return -log_p / (np.prod(input_size) * np.log(2.)) + + +def compute_loss_array(pz, z, pys, ys, ldj, args): + """ + Computes the cross entropy loss function while summing over batch dimension, not averaged! + :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits + :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1]. + :param z_mu: mean of z_0 + :param z_var: variance of z_0 + :param z_0: first stochastic latent variable + :param z_k: last stochastic latent variable + :param ldj: log det jacobian + :param args: global parameter settings + :param beta: beta for kl loss + :return: loss, ce, kl + """ + bpd_per_prior = [] + + # Likelihood of final representation. 
+ log_pz = compute_log_pz(pz, z, args) + + bpd_per_prior.append(convert_bpd(log_pz.detach(), args.input_size)) + + log_p = log_pz + + # Add likelihoods of intermediate representations. + if ys: + log_pys = compute_log_ps(pys, ys, args) + + for log_py in log_pys: + log_p += log_py + + bpd_per_prior.append(convert_bpd(log_py.detach(), args.input_size)) + + log_p += ldj + + loss = -log_p + bpd = convert_bpd(log_p.detach(), args.input_size) + + return loss, bpd, bpd_per_prior + + +def calculate_loss(pz, z, pys, ys, ldj, loss_aux, args): + return compute_loss_function(pz, z, pys, ys, ldj, loss_aux, args) diff --git a/integer_discrete_flows/optimization/training.py b/integer_discrete_flows/optimization/training.py new file mode 100644 index 0000000..b646bd9 --- /dev/null +++ b/integer_discrete_flows/optimization/training.py @@ -0,0 +1,174 @@ +from __future__ import print_function +import torch + +from optimization.loss import calculate_loss +from utils.visual_evaluation import plot_reconstructions + +import numpy as np + + +def train(epoch, train_loader, model, opt, args): + model.train() + train_loss = np.zeros(len(train_loader)) + train_bpd = np.zeros(len(train_loader)) + + num_data = 0 + + for batch_idx, (data, _) in enumerate(train_loader): + data = data.view(-1, *args.input_size) + + data = data.to(args.device) + + opt.zero_grad() + loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data) + + loss = torch.mean(loss) + bpd = torch.mean(bpd) + bpd_per_prior = [torch.mean(i) for i in bpd_per_prior] + + loss.backward() + loss = loss.item() + train_loss[batch_idx] = loss + train_bpd[batch_idx] = bpd + + ldj = torch.mean(ldj).item() / np.prod(args.input_size) / np.log(2) + + opt.step() + + num_data += len(data) + + if batch_idx % args.log_interval == 0: + perc = 100. 
* batch_idx / len(train_loader) + + tmp = 'Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] \tLoss: {:11.6f}\tbpd: {:8.6f}\tbits ldj: {:8.6f}' + print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, loss, bpd, ldj)) + + print('z min: {:8.3f}, max: {:8.3f}'.format(torch.min(z).item() * 256, torch.max(z).item() * 256)) + + print('z bpd: {:.3f}'.format(bpd_per_prior[0])) + for i in range(1, len(bpd_per_prior)): + print('y{} bpd: {:.3f}'.format(i-1, bpd_per_prior[i])) + + print('pz mu', np.mean(pz[0].data.cpu().numpy(), axis=(0, 1, 2, 3))) + print('pz logs ', np.mean(pz[1].data.cpu().numpy(), axis=(0, 1, 2, 3))) + if len(pz) == 3: + print('pz pi ', np.mean(pz[2].data.cpu().numpy(), axis=(0, 1, 2, 3))) + + for i, py in enumerate(pys): + print('py{} mu '.format(i), np.mean(py[0].data.cpu().numpy(), axis=(0, 1, 2, 3))) + print('py{} logs '.format(i), np.mean(py[1].data.cpu().numpy(), axis=(0, 1, 2, 3))) + + from utils.visual_evaluation import plot_images + import os + if not os.path.exists(args.snap_dir + 'training/'): + os.makedirs(args.snap_dir + 'training/') + + print('====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'.format( + epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader))) + + return train_loss, train_bpd + + +def evaluate(train_loader, val_loader, model, model_sample, args, testing=False, file=None, epoch=0): + model.eval() + + loss_type = 'bpd' + + def analyse(data_loader, plot=False): + bpds = [] + batch_idx = 0 + with torch.no_grad(): + for data, _ in data_loader: + batch_idx += 1 + + if args.cuda: + data = data.cuda() + + data = data.view(-1, *args.input_size) + + loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj = \ + model(data) + loss = torch.mean(loss).item() + batch_bpd = torch.mean(batch_bpd).item() + + bpds.append(batch_bpd) + + bpd = np.mean(bpds) + + with torch.no_grad(): + if not testing and plot: + x_sample = model_sample.sample(n=100) + + try: + plot_reconstructions( + x_sample, bpd, loss_type, 
epoch, args) + except: + print('Not plotting') + + return bpd + + bpd_train = analyse(train_loader) + bpd_val = analyse(val_loader, plot=True) + + with open(file, 'a') as ff: + msg = 'epoch {}\ttrain bpd {:.3f}\tval bpd {:.3f}\t'.format( + epoch, + bpd_train, + bpd_val) + print(msg, file=ff) + + loss = bpd_val * np.prod(args.input_size) * np.log(2.) + bpd = bpd_val + + file = None + + # Compute log-likelihood + with torch.no_grad(): + if testing: + test_data = val_loader.dataset.data_tensor + + if args.cuda: + test_data = test_data.cuda() + + print('Computing log-likelihood on test set') + + model.eval() + + log_likelihood = analyse(test_data) + + else: + log_likelihood = None + nll_bpd = None + + if file is None: + if testing: + print('====> Test set loss: {:.4f}'.format(loss)) + print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood)) + + print('====> Test set bpd (elbo): {:.4f}'.format(bpd)) + print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood/ + (np.prod(args.input_size) * np.log(2.)))) + + else: + print('====> Validation set loss: {:.4f}'.format(loss)) + print('====> Validation set bpd: {:.4f}'.format(bpd)) + else: + with open(file, 'a') as ff: + if testing: + print('====> Test set loss: {:.4f}'.format(loss), file=ff) + print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood), file=ff) + + print('====> Test set bpd: {:.4f}'.format(bpd), file=ff) + print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood / + (np.prod(args.input_size) * np.log(2.))), + file=ff) + + else: + print('====> Validation set loss: {:.4f}'.format(loss), file=ff) + print('====> Validation set bpd: {:.4f}'.format(loss / (np.prod(args.input_size) * np.log(2.))), + file=ff) + + if not testing: + return loss, bpd + else: + return log_likelihood, nll_bpd diff --git a/integer_discrete_flows/utils/__init__.py b/integer_discrete_flows/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
def log_min_exp(a, b, epsilon=1e-8):
    """
    Evaluate log(exp(a) - exp(b)) in a numerically friendlier form.

    Factoring exp(a) out of the difference gives
        log(exp(a) - exp(b)) = a + log(1 - exp(b - a)),
    which is what is computed here. The caller is expected to pass b < a;
    epsilon is added inside the logarithm so that b == a stays finite.
    """
    ratio = torch.exp(b - a)
    return torch.log(1 - ratio + epsilon) + a
def sample_logistic(mean, logscale):
    """
    Draw one logistic sample per element by pushing uniform noise through the logit.

    :param mean: tensor of locations; the result has its shape and dtype
    :param logscale: tensor of log-scales, broadcastable against ``mean``
    :return: ``exp(logscale) * log(y / (1 - y)) + mean`` with ``y`` uniform noise
    """
    y = torch.rand_like(mean)

    # rand_like draws from [0, 1): an exact 0 would give log(0) = -inf and a
    # non-finite sample. Clamping to the smallest positive normal number leaves
    # every other draw untouched.
    y = y.clamp(min=torch.finfo(y.dtype).tiny)

    x = torch.exp(logscale) * torch.log(y / (1 - y)) + mean

    return x
def sample_mixture_discretized_logistic(mean, logs, pi, inverse_bin_width):
    """
    Sample from a per-element mixture of logistics and round to the bin grid.

    :param mean: component locations, shape ``(b, c, h, w, n_mixtures)``
    :param logs: component log-scales, same shape as ``mean``
    :param pi: non-negative mixture weights, same shape as ``mean``
    :param inverse_bin_width: number of bins per unit; samples are rounded to
        multiples of ``1 / inverse_bin_width``
    :return: tensor of shape ``(b, c, h, w)``
    """
    # Sample one mixture component per element.
    b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
    n_elements = b * c * h * w
    # reshape (instead of view) also accepts non-contiguous inputs
    sampled_pi = torch.multinomial(pi.reshape(n_elements, n_mixtures), num_samples=1)

    # Select the parameters of the sampled component; gather indexes on the
    # tensors' own device and needs no auxiliary arange.
    mean = mean.reshape(n_elements, n_mixtures).gather(1, sampled_pi).view(b, c, h, w)
    logs = logs.reshape(n_elements, n_mixtures).gather(1, sampled_pi).view(b, c, h, w)

    y = torch.rand_like(mean)
    # rand_like draws from [0, 1): guard against log(0) = -inf for an exact 0.
    y = y.clamp(min=torch.finfo(y.dtype).tiny)
    x = torch.exp(logs) * torch.log(y / (1 - y)) + mean

    x = torch.round(x * inverse_bin_width) / inverse_bin_width

    return x
class PadToMultiple(object):
    """Pad the right and bottom of an image so both sides become multiples of ``multiple``."""

    def __init__(self, multiple, fill=0, padding_mode='constant'):
        """
        :param multiple: number both image sides are padded up to a multiple of
        :param fill: fill value handed to the pad function
        :param padding_mode: one of 'constant', 'edge', 'reflect', 'symmetric'
        """
        assert isinstance(multiple, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']

        self.multiple = multiple
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.
        Returns:
            PIL Image: Padded image.
        """
        w, h = img.size
        m = self.multiple
        # Distance from each side length to the next multiple of m (0 if already one).
        padw = (-w) % m
        padh = (-h) % m

        # Padding is given as (left, top, right, bottom): only right/bottom grow.
        return vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode)

    def __repr__(self):
        # The attribute is `multiple`; the former misspelling `mulitple` made
        # every repr() call raise AttributeError.
        return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\
            format(self.multiple, self.fill, self.padding_mode)
+ """ + + def __init__(self, *tensors, transform=None): + assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) + self.tensors = tensors + self.transform = transform + + def __getitem__(self, index): + from PIL import Image + + X, y = self.tensors + X_i, y_i, = X[index], y[index] + + if self.transform: + X_i = self.transform(X_i) + X_i = torch.from_numpy(np.array(X_i, copy=False)) + X_i = X_i.permute(2, 0, 1) + + return X_i, y_i + + def __len__(self): + return self.tensors[0].size(0) + + +def load_cifar10(args, **kwargs): + # set args + args.input_size = [3, 32, 32] + args.input_type = 'continuous' + args.dynamic_binarization = False + + from keras.datasets import cifar10 + (x_train, y_train), (x_test, y_test) = cifar10.load_data() + + x_train = x_train.transpose(0, 3, 1, 2) + x_test = x_test.transpose(0, 3, 1, 2) + + import math + + if args.data_augmentation_level == 2: + data_transform = transforms.Compose([ + transforms.ToPILImage(), + transforms.RandomHorizontalFlip(), + transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'), + transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)), + transforms.CenterCrop(32) + ]) + elif args.data_augmentation_level == 1: + data_transform = transforms.Compose([ + transforms.ToPILImage(), + transforms.RandomHorizontalFlip(), + ]) + else: + data_transform = transforms.Compose([ + transforms.ToPILImage(), + ]) + + x_val = x_train[-10000:] + y_val = y_train[-10000:] + + x_train = x_train[:-10000] + y_train = y_train[:-10000] + + train = CustomTensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train), transform=data_transform) + train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs) + + validation = data_utils.TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val)) + val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs) + + test = data_utils.TensorDataset(torch.from_numpy(x_test), 
def extract_tar(tarpath, files_per_dir=50000):
    """
    Extract a ``.tar`` archive into numbered sub-directories next to it.

    Every non-directory member is written under its base name into
    ``<tarpath without .tar>/images<k>/``, with at most ``files_per_dir``
    files per sub-directory. If the target directory already exists nothing
    is extracted and its path is returned immediately.

    :param tarpath: path of the archive; must end with ``.tar``
    :param files_per_dir: maximum number of files per ``images<k>`` directory
    :return: the target directory path (with a trailing ``/``)
    """
    assert tarpath.endswith('.tar')

    startdir = tarpath[:-4] + '/'

    if os.path.exists(startdir):
        return startdir

    print('Extracting', tarpath)

    os.makedirs(startdir, exist_ok=True)

    with tarfile.open(name=tarpath) as tar:
        n_files = 0
        path = None

        # Iterating the TarFile ends cleanly at the end of the archive, also when
        # the last entries are directories (the manual next() loop dereferenced
        # None in that case).
        for member in tar:
            # Skip directories
            if member.isdir():
                continue

            # Open a new sub-directory only when a file is about to be written
            # into it, so no empty trailing directory is left behind.
            if n_files % files_per_dir == 0:
                path = join(startdir, 'images{}'.format(n_files // files_per_dir))
                os.makedirs(path, exist_ok=True)

                print(path)

            # Flatten: keep only the base name of the member.
            member.name = member.name.split('/')[-1]

            tar.extract(member, path=path)
            n_files += 1

    return startdir
if __name__ == '__main__':
    # Manual smoke test: build the ImageNet 32x32 loaders with a minimal
    # stand-in for the parsed command-line arguments.
    class Args():
        def __init__(self):
            self.batch_size = 128

    # The loader is load_imagenet(resolution, args); no function named
    # load_imagenet32 exists, so the previous call raised NameError.
    train_loader, val_loader, test_loader, args = load_imagenet(32, Args())
* N_test) * 100)) + + x_single = X[j].unsqueeze(0) + + a = [] + for r in range(0, R): + # Repeat it for all training points + x = x_single.expand(S, *x_single.size()[1:]).contiguous() + + x_mean, z_mu, z_var, ldj, z0, zk = model(x) + + a_tmp = calculate_loss_array(x_mean, x, z_mu, z_var, z0, zk, ldj, args) + + a.append(-a_tmp.cpu().data.numpy()) + + # calculate max + a = np.asarray(a) + a = np.reshape(a, (a.shape[0] * a.shape[1], 1)) + likelihood_x = logsumexp(a) + likelihood_test.append(likelihood_x - np.log(len(a))) + + likelihood_test = np.array(likelihood_test) + + nll = -np.mean(likelihood_test) + + if args.input_type == 'multinomial': + bpd = nll/(np.prod(args.input_size) * np.log(2.)) + elif args.input_type == 'binary': + bpd = 0. + else: + raise ValueError('invalid input type!') + + return nll, bpd diff --git a/integer_discrete_flows/utils/plotting.py b/integer_discrete_flows/utils/plotting.py new file mode 100644 index 0000000..6259834 --- /dev/null +++ b/integer_discrete_flows/utils/plotting.py @@ -0,0 +1,104 @@ +from __future__ import division +from __future__ import print_function + +import numpy as np +import matplotlib +# noninteractive background +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +def plot_training_curve(train_loss, validation_loss, fname='training_curve.pdf', labels=None): + """ + Plots train_loss and validation loss as a function of optimization iteration + :param train_loss: np.array of train_loss (1D or 2D) + :param validation_loss: np.array of validation loss (1D or 2D) + :param fname: output file name + :param labels: if train_loss and validation loss are 2D, then labels indicate which variable is varied + accross training curves. 
+ :return: None + """ + + plt.close() + + matplotlib.rcParams.update({'font.size': 14}) + matplotlib.rcParams['mathtext.fontset'] = 'stix' + matplotlib.rcParams['font.family'] = 'STIXGeneral' + + if len(train_loss.shape) == 1: + # Single training curve + fig, ax = plt.subplots(nrows=1, ncols=1) + figsize = (6, 4) + + if train_loss.shape[0] == validation_loss.shape[0]: + # validation score evaluated every iteration + x = np.arange(train_loss.shape[0]) + ax.plot(x, train_loss, '-', lw=2., color='black', label='train') + ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val') + + elif train_loss.shape[0] % validation_loss.shape[0] == 0: + # validation score evaluated every epoch + x = np.arange(train_loss.shape[0]) + ax.plot(x, train_loss, '-', lw=2., color='black', label='train') + + x = np.arange(validation_loss.shape[0]) + x = (x + 1) * train_loss.shape[0] / validation_loss.shape[0] + ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val') + else: + raise ValueError('Length of train_loss and validation_loss must be equal or divisible') + + miny = np.minimum(validation_loss.min(), train_loss.min()) - 20. + maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30. 
+ ax.set_ylim([miny, maxy]) + + elif len(train_loss.shape) == 2: + # Multiple training curves + + cmap = plt.cm.brg + + cNorm = matplotlib.colors.Normalize(vmin=0, vmax=train_loss.shape[0]) + scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap) + + fig, ax = plt.subplots(nrows=1, ncols=1) + figsize = (6, 4) + + if labels is None: + labels = ['%d' % i for i in range(train_loss.shape[0])] + + if train_loss.shape[1] == validation_loss.shape[1]: + for i in range(train_loss.shape[0]): + color_val = scalarMap.to_rgba(i) + + # validation score evaluated every iteration + x = np.arange(train_loss.shape[0]) + ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i]) + ax.plot(x, validation_loss[i], '--', lw=2., color=color_val) + + elif train_loss.shape[1] % validation_loss.shape[1] == 0: + for i in range(train_loss.shape[0]): + color_val = scalarMap.to_rgba(i) + + # validation score evaluated every epoch + x = np.arange(train_loss.shape[1]) + ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i]) + + x = np.arange(validation_loss.shape[1]) + x = (x+1) * train_loss.shape[1] / validation_loss.shape[1] + ax.plot(x, validation_loss[i], '-', lw=2., color=color_val) + + miny = np.minimum(validation_loss.min(), train_loss.min()) - 20. + maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30. 
def plot_images(args, x_sample, dir, file_name, size_x=10, size_y=10):
    """
    Tile a batch of images into a ``size_y`` x ``size_x`` mosaic and save it as PNG.

    :param args: unused; kept so existing callers keep working
    :param x_sample: numpy array of shape ``(batch, channels, height, width)``
    :param dir: output directory prefix, concatenated directly with ``file_name``
    :param file_name: output file name without the ``.png`` extension
    :param size_x: number of tiles per mosaic row
    :param size_y: number of tile rows
    """
    batch, channels, height, width = x_sample.shape

    print(x_sample.shape)

    mosaic = np.zeros((height * size_y, width * size_x, channels))

    # Tiles beyond the end of the batch stay zero instead of raising IndexError.
    for idx in range(min(batch, size_x * size_y)):
        j, i = divmod(idx, size_x)

        # Rows advance by the image height, columns by the image width
        # (the column slice used `height`, which only worked for square images).
        mosaic[j*height:(j+1)*height, i*width:(i+1)*width] = \
            x_sample[idx].transpose(1, 2, 0)

    # Remove channel for BW images
    mosaic = mosaic.squeeze()

    imageio.imwrite(dir + file_name + '.png', mosaic)