diff --git a/CNN-model/dataset_loaders/Dataset.py b/dataset_loaders/Dataset.py similarity index 100% rename from CNN-model/dataset_loaders/Dataset.py rename to dataset_loaders/Dataset.py diff --git a/CNN-model/dataset_loaders/EnWik9.py b/dataset_loaders/EnWik9.py similarity index 100% rename from CNN-model/dataset_loaders/EnWik9.py rename to dataset_loaders/EnWik9.py diff --git a/CNN-model/dataset_loaders/LoremIpsumDataset.py b/dataset_loaders/LoremIpsumDataset.py similarity index 100% rename from CNN-model/dataset_loaders/LoremIpsumDataset.py rename to dataset_loaders/LoremIpsumDataset.py diff --git a/CNN-model/dataset_loaders/OpenGenomeDataset.py b/dataset_loaders/OpenGenomeDataset.py similarity index 100% rename from CNN-model/dataset_loaders/OpenGenomeDataset.py rename to dataset_loaders/OpenGenomeDataset.py diff --git a/CNN-model/dataset_loaders/__init__.py b/dataset_loaders/__init__.py similarity index 100% rename from CNN-model/dataset_loaders/__init__.py rename to dataset_loaders/__init__.py diff --git a/integer_discrete_flows/LICENSE b/integer_discrete_flows/LICENSE deleted file mode 100644 index 4784841..0000000 --- a/integer_discrete_flows/LICENSE +++ /dev/null @@ -1,19 +0,0 @@ -Copyright (c) 2019 Emiel Hoogeboom, Jorn Peters, Rianne van den Berg, Max Welling - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/integer_discrete_flows/README.md b/integer_discrete_flows/README.md deleted file mode 100644 index fa0a33e..0000000 --- a/integer_discrete_flows/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# Integer Discrete Flows and Lossless Compression - -This repository contains the code for the experiments presented in [1]. - -## Usage - -### CIFAR10 setup: -``` -python main_experiment.py --n_flows 8 --n_levels 3 --n_channels 512 --coupling_type 'densenet' --densenet_depth 12 —n_mixtures 5 —splitprior_type ‘densenet’ -``` - - -### ImageNet32 setup: -``` -python main_experiment.py --evaluate_interval_epochs 5 --n_flows 8 --n_levels 3 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet32' --epochs 100 --lr_decay 0.99 -``` - - -### ImageNet64 setup: -``` -python main_experiment.py --evaluate_interval_epochs 1 --n_flows 8 --n_levels 4 --n_channels 512 --n_mixtures 5 --densenet_depth 12 --coupling_type 'densenet' --splitprior_type 'densenet' --dataset 'imagenet64' --epochs 20 --lr_decay 0.99 --batch_size 64 -``` - -# Acknowledgements -The Robert Bosch GmbH is acknowledged for financial support. - -# References -[1] Hoogeboom, Emiel, Jorn WT Peters, Rianne van den Berg, and Max Welling. "Integer Discrete Flows and Lossless Compression." Conference on Neural Information Processing Systems (2019). - diff --git a/integer_discrete_flows/coding/__init__.py b/integer_discrete_flows/coding/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/integer_discrete_flows/coding/coder.py b/integer_discrete_flows/coding/coder.py deleted file mode 100644 index 1303006..0000000 --- a/integer_discrete_flows/coding/coder.py +++ /dev/null @@ -1,132 +0,0 @@ -import numpy as np -from . import rans -from utils.distributions import discretized_logistic_cdf, \ - mixture_discretized_logistic_cdf -import torch - -precision = 24 -n_bins = 4096 - - -def cdf_fn(z, pz, variable_type, distribution_type, inverse_bin_width): - if variable_type == 'discrete': - if distribution_type == 'logistic': - if len(pz) == 2: - return discretized_logistic_cdf( - z, *pz, inverse_bin_width=inverse_bin_width) - elif len(pz) == 3: - return mixture_discretized_logistic_cdf( - z, *pz, inverse_bin_width=inverse_bin_width) - elif distribution_type == 'normal': - pass - - elif variable_type == 'continuous': - if distribution_type == 'logistic': - pass - elif distribution_type == 'normal': - pass - elif distribution_type == 'steplogistic': - pass - raise ValueError - - -def CDF_fn(pz, bin_width, variable_type, distribution_type): - mean = pz[0] if len(pz) == 2 else pz[0][..., (pz[0].size(-1) - 1) // 2] - MEAN = torch.round(mean / bin_width).long() - - bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None] - bin_locations = bin_locations.float() * bin_width - bin_locations = bin_locations.to(device=pz[0].DEVICE) - - pz = [param[:, :, :, :, None] for param in pz] - cdf = cdf_fn( - bin_locations - bin_width, - pz, - variable_type, - distribution_type, - 1./bin_width).cpu().numpy() - - # Compute CDFs, reweigh to give all bins at least - # 1 / (2^precision) probability. - # CDF is equal to floor[cdf * (2^precision - n_bins)] + range(n_bins) - CDFs = (cdf * ((1 << precision) - n_bins)).astype('int') \ - + np.arange(n_bins) - - return CDFs, MEAN - - -def encode_sample( - z, pz, variable_type, distribution_type, bin_width=1./256, state=None): - if state is None: - state = rans.x_init - else: - state = rans.unflatten(state) - - CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type) - - # z is transformed to Z to match the indices for the CDFs array - Z = torch.round(z / bin_width).long() + n_bins // 2 - MEAN - Z = Z.cpu().numpy() - - if not ((np.sum(Z < 0) == 0 and np.sum(Z >= n_bins-1) == 0)): - print('Z out of allowed range of values, canceling compression') - return None - - Z, CDFs = Z.reshape(-1), CDFs.reshape(-1, n_bins).copy() - for symbol, cdf in zip(Z[::-1], CDFs[::-1]): - statfun = statfun_encode(cdf) - state = rans.append_symbol(statfun, precision)(state, symbol) - - state = rans.flatten(state) - - return state - - -def decode_sample( - state, pz, variable_type, distribution_type, bin_width=1./256): - state = rans.unflatten(state) - - device = pz[0].DEVICE - size = pz[0].size()[0:4] - - CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type) - - CDFs = CDFs.reshape(-1, n_bins) - result = np.zeros(len(CDFs), dtype=int) - for i, cdf in enumerate(CDFs): - statfun = statfun_decode(cdf) - state, symbol = rans.pop_symbol(statfun, precision)(state) - result[i] = symbol - - Z_flat = torch.from_numpy(result).to(device) - Z = Z_flat.view(size) - n_bins // 2 + MEAN - - z = Z.float() * bin_width - - state = rans.flatten(state) - - return state, z - - -def statfun_encode(CDF): - def _statfun_encode(symbol): - return CDF[symbol], CDF[symbol + 1] - CDF[symbol] - return _statfun_encode - - -def statfun_decode(CDF): - def _statfun_decode(cf): - # Search such that CDF[s] <= cf < CDF[s] - s = np.searchsorted(CDF, cf, side='right') - s = s - 1 - start, freq = statfun_encode(CDF)(s) - return s, (start, freq) - return _statfun_decode - - -def encode(x, symbol): - return rans.append_symbol(statfun_encode, precision)(x, symbol) - - -def decode(x): - return rans.pop_symbol(statfun_decode, precision)(x) diff --git a/integer_discrete_flows/coding/rans.py b/integer_discrete_flows/coding/rans.py deleted file mode 100644 index 02583a5..0000000 --- a/integer_discrete_flows/coding/rans.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -Closely based on https://github.com/rygorous/ryg_rans/blob/master/rans64.h - -ROUGH GUIDE: -We use the pythonic names 'append' and 'pop' for encoding and decoding -respectively. The compressed state 'x' is an immutable stack, implemented using -a cons list. - -x: the current stack-like state of the encoder/decoder. - -precision: the natural numbers are divided into ranges of size 2^precision. - -start & freq: start indicates the beginning of the range in [0, 2^precision-1] -that the current symbol is represented by. freq is the length of the range. -freq is chosen such that p(symbol) ~= freq/2^precision. -""" -import numpy as np -from functools import reduce - - -rans_l = 1 << 31 # the lower bound of the normalisation interval -tail_bits = (1 << 32) - 1 - -x_init = (rans_l, ()) - -def append(x, start, freq, precision): - """Encodes a symbol with range [start, start + freq). All frequencies are - assumed to sum to "1 << precision", and the resulting bits get written to - x.""" - if x[0] >= ((rans_l >> precision) << 32) * freq: - x = (x[0] >> 32, (x[0] & tail_bits, x[1])) - return ((x[0] // freq) << precision) + (x[0] % freq) + start, x[1] - -def pop(x_, precision): - """Advances in the bit stream by "popping" a single symbol with range start - "start" and frequency "freq".""" - cf = x_[0] & ((1 << precision) - 1) - def pop(start, freq): - x = freq * (x_[0] >> precision) + cf - start, x_[1] - return ((x[0] << 32) | x[1][0], x[1][1]) if x[0] < rans_l else x - return cf, pop - -def append_symbol(statfun, precision): - def append_(x, symbol): - start, freq = statfun(symbol) - return append(x, start, freq, precision) - return append_ - -def pop_symbol(statfun, precision): - def pop_(x): - cf, pop_fun = pop(x, precision) - symbol, (start, freq) = statfun(cf) - return pop_fun(start, freq), symbol - return pop_ - -def flatten(x): - """Flatten a rans state x into a 1d numpy array.""" - out, x = [x[0] >> 32, x[0]], x[1] - while x: - x_head, x = x - out.append(x_head) - return np.asarray(out, dtype=np.uint32) - -def unflatten(arr): - """Unflatten a 1d numpy array into a rans state.""" - return (int(arr[0]) << 32 | int(arr[1]), - reduce(lambda tl, hd: (int(hd), tl), reversed(arr[2:]), ())) diff --git a/integer_discrete_flows/experiment_coding.py b/integer_discrete_flows/experiment_coding.py deleted file mode 100644 index 1d0cd75..0000000 --- a/integer_discrete_flows/experiment_coding.py +++ /dev/null @@ -1,188 +0,0 @@ -# !/usr/bin/env python -# -*- coding: utf-8 -*- - -from __future__ import print_function -import argparse -import torch -import torch.utils.data -import numpy as np - -from utils.load_data import load_dataset - - -parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows') - -parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'], - metavar='DATASET', - help='Dataset choice.') - -parser.add_argument('-nc', '--no_cuda', action='store_true', default=False, - help='disables CUDA training') - -parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.') - -parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL', - help='how many batches to wait before logging training status') - -parser.add_argument('--evaluate_interval_epochs', type=int, default=25, - help='Evaluate per how many epochs') - - -# optimization settings -parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS', - help='number of epochs to train (default: 2000)') -parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING', - help='number of early stopping epochs') - -parser.add_argument('-bs', '--batch_size', type=int, default=10, metavar='BATCH_SIZE', - help='input batch size for training (default: 100)') -parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE', - help='learning rate') -parser.add_argument('--warmup', type=int, default=10, - help='number of warmup epochs') - -parser.add_argument('--data_augmentation_level', type=int, default=2, - help='data augmentation level') - -parser.add_argument('--no_decode', action='store_true', default=False, - help='disables decoding') - - -args = parser.parse_args() -args.cuda = not args.no_cuda and torch.cuda.is_available() - -kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {} - - -def encode_images(img, model, decode): - batchsize, img_c, img_h, img_w = img.size() - c, h, w = model.args.input_size - - assert img_h == img_w and h == w - - if img_h != h: - assert img_h % h == 0 - steps = img_h // h - - states = [[] for i in range(batchsize)] - state_sizes = [0 for i in range(batchsize)] - bpd = [0 for i in range(batchsize)] - error = 0 - - for j in range(steps): - for i in range(steps): - r = encode_patches( - img[:, :, j*h:(j+1)*h, i*w:(i+1)*w], model, decode) - for b in range(batchsize): - - if r[0][b] is None: - states[b].append(None) - else: - states[b].extend(r[0][b]) - state_sizes[b] += r[1][b] - bpd[b] += r[2][b] / steps**2 - error += r[3] - return states, state_sizes, bpd, error - else: - return encode_patches(img, model, decode) - - -def encode_patches(imgs, model, decode): - batchsize, img_c, img_h, img_w = imgs.size() - c, h, w = model.args.input_size - assert img_h == h and img_w == w - - states = model.encode(imgs) - - bpd = model.forward(imgs)[1].cpu().numpy() - - state_sizes = [] - error = 0 - - for b in range(batchsize): - if states[b] is None: - # Using escape bit ;) - state_sizes += [8 * img_c * img_h * img_w + 1] - - # Error remains unchanged. - print('Escaping, not encoding.') - - else: - if decode: - x_recon = model.decode([states[b]]) - - error += torch.sum( - torch.abs(x_recon.int() - imgs[b].int())).item() - - # Append state plus an escape bit - state_sizes += [32 * len(states[b]) + 1] - - return states, state_sizes, bpd, error - - -def run(args, kwargs): - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - args.snap_dir = snap_dir = \ - 'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/' - - # ================================================================================================================== - # SNAPSHOTS - # ================================================================================================================== - - # ================================================================================================================== - # LOAD DATA - # ================================================================================================================== - train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs) - - final_model = torch.load(snap_dir + 'a.model') - if hasattr(final_model, 'module'): - final_model = final_model.module - final_model = final_model.cuda() - - sizes = [] - errors = [] - bpds = [] - - import time - start = time.time() - - t = 0 - with torch.no_grad(): - for data, _ in test_loader: - if args.cuda: - data = data.cuda() - - state, state_sizes, bpd, error = \ - encode_images(data, final_model, decode=not args.no_decode) - - errors += [error] - bpds.extend(bpd) - sizes.extend(state_sizes) - - t += len(data) - - print( - 'Examples: {}/{} bpd compression: {:.3f} error: {},' - ' analytical bpd {:.3f}'.format( - t, len(test_loader.dataset), - np.mean(sizes) / np.prod(data.size()[1:]), - np.sum(errors), - np.mean(bpds) - )) - - if args.no_decode: - print('Not testing decoding.') - else: - print('Error: {}'.format(np.sum(errors))) - - print('Took {:.3f} seconds / example'.format((time.time() - start) / t)) - print('Final bpd: {:.3f} error: {}'.format( - np.mean(sizes) / np.prod(data.size()[1:]), - np.sum(errors))) - - -if __name__ == "__main__": - - run(args, kwargs) diff --git a/integer_discrete_flows/experiment_progressive_loading.py b/integer_discrete_flows/experiment_progressive_loading.py deleted file mode 100644 index cbd01cd..0000000 --- a/integer_discrete_flows/experiment_progressive_loading.py +++ /dev/null @@ -1,105 +0,0 @@ -# !/usr/bin/env python -# -*- coding: utf-8 -*- - -from __future__ import print_function -import argparse -import time -import torch -import torch.utils.data -import torch.optim as optim -import numpy as np -import matplotlib.pyplot as plt -from torchvision.utils import make_grid - -import os -from optimization.training import train, evaluate -from utils.load_data import load_dataset -from utils.plotting import plot_training_curve -import imageio - - -parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows') - -parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'], - metavar='DATASET', - help='Dataset choice.') - -parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE', - help='input batch size for training (default: 100)') - -parser.add_argument('--data_augmentation_level', type=int, default=2, - help='data augmentation level') - -parser.add_argument('-nc', '--no_cuda', action='store_true', default=False, - help='disables CUDA training') - - -args = parser.parse_args() -args.cuda = not args.no_cuda and torch.cuda.is_available() - -kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {} - - -def run(args, kwargs): - - args.snap_dir = snap_dir = \ - 'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/' - - # ================================================================================================================== - # SNAPSHOTS - # ================================================================================================================== - - # ================================================================================================================== - # LOAD DATA - # ================================================================================================================== - train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs) - - final_model = torch.load(snap_dir + 'a.model') - if hasattr(final_model, 'module'): - final_model = final_model.module - - from models.backround import SmoothRound - for module in final_model.modules(): - if isinstance(module, SmoothRound): - module._round_decay = 1. - - exp_dir = snap_dir + 'partials/' - os.makedirs(exp_dir, exist_ok=True) - - images = [] - with torch.no_grad(): - for data, _ in test_loader: - - if args.cuda: - data = data.cuda() - - for i in range(len(data)): - _, _, _, pz, z, pys, ys, ldj = final_model.forward(data[i:i+1]) - - for j in range(len(ys) + 1): - x_recon = final_model.inverse( - z, - ys[len(ys) - j:]) - - images.append(x_recon.float()) - - if i == 10: - break - break - - for j in range(len(ys) + 1): - - grid = make_grid( - torch.stack(images[j::len(ys) + 1], dim=0).squeeze(), - nrow=11, padding=0, - normalize=True, range=None, - scale_each=False, pad_value=0) - - imageio.imwrite( - exp_dir + 'loaded{j}.png'.format(j=j), - grid.cpu().numpy().transpose(1, 2, 0)) - - -if __name__ == "__main__": - - run(args, kwargs) diff --git a/integer_discrete_flows/main_experiment.py b/integer_discrete_flows/main_experiment.py deleted file mode 100644 index 9f5acfc..0000000 --- a/integer_discrete_flows/main_experiment.py +++ /dev/null @@ -1,285 +0,0 @@ -# !/usr/bin/env python -# -*- coding: utf-8 -*- - -from __future__ import print_function -import argparse -import time -import torch -import torch.utils.data -import torch.optim as optim -import numpy as np -import math -import random - -import os - -import datetime - -from optimization.training import train, evaluate -from utils.load_data import load_dataset - -parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows') - -parser.add_argument('-d', '--dataset', type=str, default='cifar10', - choices=['cifar10', 'imagenet32', 'imagenet64'], - metavar='DATASET', - help='Dataset choice.') - -parser.add_argument('-nc', '--no_cuda', action='store_true', default=False, - help='disables CUDA training') - -parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.') - -parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL', - help='how many batches to wait before logging training status') - -parser.add_argument('--evaluate_interval_epochs', type=int, default=25, - help='Evaluate per how many epochs') - -parser.add_argument('-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR', - help='output directory for model snapshots etc.') - -fp = parser.add_mutually_exclusive_group(required=False) -fp.add_argument('-te', '--testing', action='store_true', dest='testing', - help='evaluate on test set after training') -fp.add_argument('-va', '--validation', action='store_false', dest='testing', - help='only evaluate on validation set') -parser.set_defaults(testing=True) - -# optimization settings -parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS', - help='number of epochs to train (default: 2000)') -parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING', - help='number of early stopping epochs') - -parser.add_argument('-bs', '--batch_size', type=int, default=256, metavar='BATCH_SIZE', - help='input batch size for training (default: 100)') -parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE', - help='learning rate') -parser.add_argument('--warmup', type=int, default=10, - help='number of warmup epochs') - -parser.add_argument('--data_augmentation_level', type=int, default=2, - help='data augmentation level') - -parser.add_argument('--variable_type', type=str, default='discrete', - help='variable type of data distribution: discrete/continuous', - choices=['discrete', 'continuous']) -parser.add_argument('--distribution_type', type=str, default='logistic', - choices=['logistic', 'normal', 'steplogistic'], - help='distribution type: logistic/normal') -parser.add_argument('--n_flows', type=int, default=8, - help='number of flows per level') -parser.add_argument('--n_levels', type=int, default=3, - help='number of levels') - -parser.add_argument('--n_bits', type=int, default=8, - help='') - -# ---------------- SETTINGS CONCERNING NETWORKS ------------- -parser.add_argument('--densenet_depth', type=int, default=8, - help='Depth of densenets') -parser.add_argument('--n_channels', type=int, default=512, - help='number of channels in coupling and splitprior') -# ---------------- ----------------------------- ------------- - - -# ---------------- SETTINGS CONCERNING COUPLING LAYERS ------------- -parser.add_argument('--coupling_type', type=str, default='shallow', - choices=['shallow', 'resnet', 'densenet'], - help='Type of coupling layer') -parser.add_argument('--splitfactor', default=0, type=int, - help='Split factor for coupling layers.') - -parser.add_argument('--split_quarter', dest='split_quarter', action='store_true', - help='Split coupling layer on quarter') -parser.add_argument('--no_split_quarter', dest='split_quarter', action='store_false') -parser.set_defaults(split_quarter=True) -# ---------------- ----------------------------------- ------------- - - -# ---------------- SETTINGS CONCERNING SPLITPRIORS ------------- -parser.add_argument('--splitprior_type', type=str, default='shallow', - choices=['none', 'shallow', 'resnet', 'densenet'], - help='Type of splitprior. Use \'none\' for no splitprior') -# ---------------- ------------------------------- ------------- - - -# ---------------- SETTINGS CONCERNING PRIORS ------------- -parser.add_argument('--n_mixtures', type=int, default=1, - help='number of mixtures') -# ---------------- ------------------------------- ------------- - -parser.add_argument('--hard_round', dest='hard_round', action='store_true', - help='Rounding of translation in discrete models. Weird ' - 'probabilistic implications, only for experimental phase') -parser.add_argument('--no_hard_round', dest='hard_round', action='store_false') -parser.set_defaults(hard_round=True) - -parser.add_argument('--round_approx', type=str, default='smooth', - choices=['smooth', 'stochastic']) - -parser.add_argument('--lr_decay', default=0.999, type=float, - help='Learning rate') - -parser.add_argument('--temperature', default=1.0, type=float, - help='Temperature used for BackRound. It is used in ' - 'the the SmoothRound module. ' - '(default=1.0') - -# gpu/cpu -parser.add_argument('--gpu_num', type=int, default=0, metavar='GPU', - help='choose GPU to run on.') - -args = parser.parse_args() -args.cuda = not args.no_cuda and torch.cuda.is_available() - -if args.manual_seed is None: - args.manual_seed = random.randint(1, 100000) -random.seed(args.manual_seed) -torch.manual_seed(args.manual_seed) -np.random.seed(args.manual_seed) - - -kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {} - - -def run(args, kwargs): - - print('\nMODEL SETTINGS: \n', args, '\n') - print("Random Seed: ", args.manual_seed) - - if 'imagenet' in args.dataset and args.evaluate_interval_epochs > 5: - args.evaluate_interval_epochs = 5 - - # ================================================================================================================== - # SNAPSHOTS - # ================================================================================================================== - args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_') - args.model_signature = args.model_signature.replace(':', '_') - - snapshots_path = os.path.join(args.out_dir, args.variable_type + '_' + args.distribution_type + args.dataset) - snap_dir = snapshots_path - - snap_dir += '_' + 'flows_' + str(args.n_flows) + '_levels_' + str(args.n_levels) - - snap_dir = snap_dir + '__' + args.model_signature + '/' - - args.snap_dir = snap_dir - - if not os.path.exists(snap_dir): - os.makedirs(snap_dir) - - with open(snap_dir + 'log.txt', 'a') as ff: - print('\nMODEL SETTINGS: \n', args, '\n', file=ff) - - # SAVING - torch.save(args, snap_dir + '.config') - - # ================================================================================================================== - # LOAD DATA - # ================================================================================================================== - train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs) - - # ================================================================================================================== - # SELECT MODEL - # ================================================================================================================== - # flow parameters and architecture choice are passed on to model through args - print(args.input_size) - - import models.Model as Model - - model = Model.Model(args) - args.DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model.set_temperature(args.temperature) - model.enable_hard_round(args.hard_round) - - model_sample = model - - # ==================================== - # INIT - # ==================================== - # data dependend initialization on CPU - for batch_idx, (data, _) in enumerate(train_loader): - model(data) - break - - if torch.cuda.device_count() > 1: - print("Let's use", torch.cuda.device_count(), "GPUs!") - model = torch.nn.DataParallel(model, dim=0) - - model.to(args.DEVICE) - - def lr_lambda(epoch): - return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch) - optimizer = optim.Adamax(model.parameters(), lr=args.learning_rate, eps=1.e-7) - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1) - - # ================================================================================================================== - # TRAINING - # ================================================================================================================== - train_bpd = [] - val_bpd = [] - - # for early stopping - best_val_bpd = np.inf - best_train_bpd = np.inf - epoch = 0 - - train_times = [] - - model.eval() - model.train() - - for epoch in range(1, args.epochs + 1): - t_start = time.time() - scheduler.step() - tr_loss, tr_bpd = train(epoch, train_loader, model, optimizer, args) - train_bpd.append(tr_bpd) - train_times.append(time.time()-t_start) - print('One training epoch took %.2f seconds' % (time.time()-t_start)) - - if epoch < 25 or epoch % args.evaluate_interval_epochs == 0: - v_loss, v_bpd = evaluate( - train_loader, val_loader, model, model_sample, args, - epoch=epoch, file=snap_dir + 'log.txt') - - val_bpd.append(v_bpd) - - # Model save based on TRAIN performance (is heavily correlated with validation performance.) - if np.mean(tr_bpd) < best_train_bpd: - best_train_bpd = np.mean(tr_bpd) - best_val_bpd = v_bpd - torch.save(model.module, snap_dir + 'a.model') - torch.save(optimizer, snap_dir + 'a.optimizer') - print('->model saved<-') - - print('(BEST: train bpd {:.4f}, test bpd {:.4f})\n'.format( - best_train_bpd, best_val_bpd)) - - if math.isnan(v_loss): - raise ValueError('NaN encountered!') - - train_bpd = np.hstack(train_bpd) - val_bpd = np.array(val_bpd) - - # training time per epoch - train_times = np.array(train_times) - mean_train_time = np.mean(train_times) - std_train_time = np.std(train_times, ddof=1) - print('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time)) - - # ================================================================================================================== - # EVALUATION - # ================================================================================================================== - final_model = torch.load(snap_dir + 'a.model') - test_loss, test_bpd = evaluate( - train_loader, test_loader, final_model, final_model, args, - epoch=epoch, file=snap_dir + 'test_log.txt') - - print('Test loss / bpd: %.2f / %.2f' % (test_loss, test_bpd)) - - -if __name__ == "__main__": - - run(args, kwargs) diff --git a/integer_discrete_flows/media/IDF_poster.pdf b/integer_discrete_flows/media/IDF_poster.pdf deleted file mode 100644 index a1c512d..0000000 Binary files a/integer_discrete_flows/media/IDF_poster.pdf and /dev/null differ diff --git a/integer_discrete_flows/media/IDF_slides.pdf b/integer_discrete_flows/media/IDF_slides.pdf deleted file mode 100644 index 3fc439e..0000000 Binary files a/integer_discrete_flows/media/IDF_slides.pdf and /dev/null differ diff --git a/integer_discrete_flows/models/Model.py b/integer_discrete_flows/models/Model.py deleted file mode 100644 index 67bb4ac..0000000 --- a/integer_discrete_flows/models/Model.py +++ /dev/null @@ -1,191 +0,0 @@ -import torch -import models.generative_flows as generative_flows -import numpy as np -from models.utils import Base -from .priors import Prior -from optimization.loss import compute_loss_array -from coding.coder import encode_sample, decode_sample - - -class Normalize(Base): - def __init__(self, args): - super().__init__() - self.n_bits = args.n_bits - self.variable_type = args.variable_type - self.input_size = args.input_size - - def forward(self, x, ldj, reverse=False): - domain = 2.**self.n_bits - - if self.variable_type == 'discrete': - # Discrete variables will be measured on intervals sized 1/domain. - # Hence, there is no need to change the log Jacobian determinant. - dldj = 0 - elif self.variable_type == 'continuous': - dldj = -np.log(domain) * np.prod(self.input_size) - else: - raise ValueError - - if not reverse: - x = (x - domain / 2) / domain - ldj += dldj - else: - x = x * domain + domain / 2 - ldj -= dldj - - return x, ldj - - -class Model(Base): - """ - The base VAE class containing gated convolutional encoder and decoder - architecture. Can be used as a base class for VAE's with normalizing flows. - """ - - def __init__(self, args): - super().__init__() - self.args = args - self.variable_type = args.variable_type - self.distribution_type = args.distribution_type - - n_channels, height, width = args.input_size - - self.normalize = Normalize(args) - - self.flow = generative_flows.GenerativeFlow( - n_channels, height, width, args) - - self.n_bits = args.n_bits - - self.z_size = self.flow.z_size - - self.prior = Prior(self.z_size, args) - - def dequantize(self, x): - if self.training: - x = x + torch.rand_like(x) - else: - # Required for stability. - alpha = 1e-3 - x = x + alpha + torch.rand_like(x) * (1 - 2 * alpha) - - return x - - def loss(self, pz, z, pys, ys, ldj): - batchsize = z.size(0) - loss, bpd, bpd_per_prior = \ - compute_loss_array(pz, z, pys, ys, ldj, self.args) - - for module in self.modules(): - if hasattr(module, 'auxillary_loss'): - loss += module.auxillary_loss() / batchsize - - return loss, bpd, bpd_per_prior - - def forward(self, x): - """ - Evaluates the model as a whole, encodes and decodes. Note that the log - det jacobian is zero for a plain VAE (without flows), and z_0 = z_k. - """ - # Decode z to x. - - assert x.dtype == torch.uint8 - - x = x.float() - - ldj = torch.zeros_like(x[:, 0, 0, 0]) - if self.variable_type == 'continuous': - x = self.dequantize(x) - elif self.variable_type == 'discrete': - pass - else: - raise ValueError - - x, ldj = self.normalize(x, ldj) - - z, ldj, pys, ys = self.flow(x, ldj, pys=(), ys=()) - - pz, z, ldj = self.prior(z, ldj) - - loss, bpd, bpd_per_prior = self.loss(pz, z, pys, ys, ldj) - - return loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj - - def inverse(self, z, ys): - ldj = torch.zeros_like(z[:, 0, 0, 0]) - x, ldj, pys, py = \ - self.flow(z, ldj, pys=[], ys=ys, reverse=True) - - x, ldj = self.normalize(x, ldj, reverse=True) - - x_uint8 = torch.clamp(x, min=0, max=255).to( - torch.uint8) - - return x_uint8 - - def sample(self, n): - z_sample = self.prior.sample(n) - - ldj = torch.zeros_like(z_sample[:, 0, 0, 0]) - x_sample, ldj, pys, py = \ - self.flow(z_sample, ldj, pys=[], ys=[], reverse=True) - - x_sample, ldj = self.normalize(x_sample, ldj, reverse=True) - - x_sample_uint8 = torch.clamp(x_sample, min=0, max=255).to( - torch.uint8) - - return x_sample_uint8 - - def encode(self, x): - batchsize = x.size(0) - _, _, _, pz, z, pys, ys, _ = self.forward(x) - - pjs = list(pys) + [pz] - js = list(ys) + [z] - - states = [] - - for b in range(batchsize): - state = None - for pj, j in zip(pjs, js): - pj_b = [param[b:b+1] for param in pj] - j_b = j[b:b+1] - - state = encode_sample( - j_b, pj_b, self.variable_type, - self.distribution_type, state=state) - if state is None: - break - - states.append(state) - - return states - - def decode(self, states): - def decode_fn(states, pj): - states = list(states) - j = [] - - for b in range(len(states)): - pj_b = [param[b:b+1] for param in pj] - - states[b], j_b = decode_sample( - states[b], pj_b, self.variable_type, - self.distribution_type) - - j.append(j_b) - - j = torch.cat(j, dim=0) - return states, j - - states, z = self.prior.decode(states, decode_fn=decode_fn) - - ldj = torch.zeros_like(z[:, 0, 0, 0]) - - x, ldj = self.flow.decode(z, ldj, states, decode_fn=decode_fn) - - x, ldj = self.normalize(x, ldj, reverse=True) - x = x.to(dtype=torch.uint8) - - return x diff --git a/integer_discrete_flows/models/__init__.py b/integer_discrete_flows/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/integer_discrete_flows/models/backround.py b/integer_discrete_flows/models/backround.py deleted file mode 100644 index a3ea233..0000000 --- a/integer_discrete_flows/models/backround.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -import torch.nn.functional as F -import numpy as np - -from models.utils import Base - - -class RoundStraightThrough(torch.autograd.Function): - - def __init__(self): - super().__init__() - - @staticmethod - def forward(ctx, input): - rounded = torch.round(input, out=None) - return rounded - - @staticmethod - def backward(ctx, grad_output): - grad_input = grad_output.clone() - return grad_input - - -_round_straightthrough = RoundStraightThrough().apply - - -def _stacked_sigmoid(x, temperature, n_approx=3): - - x_ = x - 0.5 - rounded = torch.round(x_) - x_remainder = x_ - rounded - - size = x_.size() - x_remainder = x_remainder.view(size + (1,)) - - translation = torch.arange(n_approx) - n_approx // 2 - translation = translation.to(device=x.DEVICE, dtype=x.dtype) - translation = translation.view([1] * len(size) + [len(translation)]) - out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1) - - return out + rounded - (n_approx // 2) - - -class SmoothRound(Base): - def __init__(self): - self._temperature = None - self._n_approx = None - super().__init__() - self.hard_round = None - - @property - def temperature(self): - return self._temperature - - @temperature.setter - def temperature(self, value): - self._temperature = value - - if self._temperature <= 0.05: - self._n_approx = 1 - elif 0.05 < self._temperature < 0.13: - self._n_approx = 3 - else: - self._n_approx = 5 - - def forward(self, x): - assert self._temperature is not None - assert self._n_approx is not None - assert self.hard_round is not None - - if self.temperature <= 0.25: - h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx) - else: - h = x - - if self.hard_round: - h = _round_straightthrough(h) - - return h - - -class StochasticRound(Base): - def __init__(self): - super().__init__() - self.hard_round = None - - def forward(self, x): - u = torch.rand_like(x) - - h = x + u - 0.5 - - if self.hard_round: - h = _round_straightthrough(h) - - return h - - -class BackRound(Base): - - def __init__(self, args, inverse_bin_width): - """ - BackRound is an approximation to Round that allows for Backpropagation. - - Approximate the round function using a sum of translated sigmoids. - The temperature determines how well the round function is approximated, - i.e., a lower temperature corresponds to a better approximation, at - the cost of more vanishing gradients. - - BackRound supports the following settings: - * By setting hard to True and temperature > 0.25, BackRound - reduces to a round function with a straight through gradient - estimator - * When using 0 < temperature <= 0.25 and hard = True, the - output in the forward pass is equivalent to a round function, but the - gradient is approximated by the gradient of a sum of sigmoids. - * When using hard = False, the output is not constrained to integers. - * When temperature > 0.25 and hard = False, BackRound reduces to - the identity function. - - Arguments - --------- - temperature: float - Temperature used for stacked sigmoid approximated. If temperature - is greater than 0.25, the approximation reduces to the indentiy - function. - hard: bool - If hard is True, a (hard) round is applied before returning. The - gradient for this is approximated using the straight-through - estimator. - """ - super().__init__() - self.inverse_bin_width = inverse_bin_width - self.round_approx = args.round_approx - - if args.round_approx == 'smooth': - self.round = SmoothRound() - elif args.round_approx == 'stochastic': - self.round = StochasticRound() - else: - raise ValueError - - def forward(self, x): - if self.round_approx == 'smooth' or self.round_approx == 'stochastic': - h = x * self.inverse_bin_width - - h = self.round(h) - - return h / self.inverse_bin_width - - else: - raise ValueError diff --git a/integer_discrete_flows/models/coupling.py b/integer_discrete_flows/models/coupling.py deleted file mode 100644 index 8645449..0000000 --- a/integer_discrete_flows/models/coupling.py +++ /dev/null @@ -1,142 +0,0 @@ -""" -Collection of flow strategies -""" - -from __future__ import print_function - -import torch -import numpy as np -from models.utils import Base -from .backround import BackRound -from .networks import NN - - -UNIT_TESTING = False - - -class SplitFactorCoupling(Base): - def __init__(self, c_in, factor, height, width, args): - super().__init__() - self.n_channels = args.n_channels - self.kernel = 3 - self.input_channel = c_in - self.round_approx = args.round_approx - - if args.variable_type == 'discrete': - self.round = BackRound( - args, inverse_bin_width=2**args.n_bits) - else: - self.round = None - - self.split_idx = c_in - (c_in // factor) - - self.nn = NN( - args=args, - c_in=self.split_idx, - c_out=c_in - self.split_idx, - height=height, - width=width, - kernel=self.kernel, - nn_type=args.coupling_type) - - def forward(self, z, ldj, reverse=False): - z1 = z[:, :self.split_idx, :, :] - z2 = z[:, self.split_idx:, :, :] - - t = self.nn(z1) - - if self.round is not None: - t = self.round(t) - - if not reverse: - z2 = z2 + t - else: - z2 = z2 - t - - z = torch.cat([z1, z2], dim=1) - - return z, ldj - - -class Coupling(Base): - def __init__(self, c_in, height, width, args): - super().__init__() - - if args.split_quarter: - factor = 4 - elif args.splitfactor > 1: - factor = args.splitfactor - else: - factor = 2 - - self.coupling = SplitFactorCoupling( - c_in, factor, height, width, args=args) - - def forward(self, z, ldj, reverse=False): - return self.coupling(z, ldj, reverse) - - -def test_generative_flow(): - import models.networks as networks - global UNIT_TESTING - - networks.UNIT_TESTING = True - UNIT_TESTING = True - - batch_size = 17 - - input_size = [12, 16, 16] - - class Args(): - def __init__(self): - self.input_size = input_size - self.learn_split = False - self.variable_type = 'continuous' - self.distribution_type = 'logistic' - self.round_approx = 'smooth' - self.coupling_type = 'shallow' - self.conv_type = 'standard' - self.densenet_depth = 8 - self.bottleneck = False - self.n_channels = 512 - self.network1x1 = 'standard' - self.auxilary_freq = -1 - self.actnorm = False - self.LU = False - self.coupling_lifting_L = True - self.splitprior = True - self.split_quarter = True - self.n_levels = 2 - self.n_flows = 2 - self.cond_L = True - self.n_bits = True - - args = Args() - - x = (torch.randint(256, size=[batch_size] + input_size).float() - 128.) / 256. - ldj = torch.zeros_like(x[:, 0, 0, 0]) - - model = Coupling(c_in=12, height=16, width=16, args=args) - - print(model) - - model.set_temperature(1.) - model.enable_hard_round() - - model.eval() - - z, ldj = model(x, ldj, reverse=False) - - # Check if gradient computation works - loss = torch.sum(z**2) - loss.backward() - - recon, ldj = model(z, ldj, reverse=True) - - sse = torch.sum(torch.pow(x - recon, 2)).item() - ae = torch.abs(x - recon).sum() - print('Error in recon: sse {} ae {}'.format(sse / np.prod(input_size), ae)) - - -if __name__ == '__main__': - test_generative_flow() diff --git a/integer_discrete_flows/models/generative_flows.py b/integer_discrete_flows/models/generative_flows.py deleted file mode 100644 index 8b7bcca..0000000 --- a/integer_discrete_flows/models/generative_flows.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -Collection of flow strategies -""" - -from __future__ import print_function - -import torch -import numpy as np -from models.utils import Base -from .priors import SplitPrior -from .coupling import Coupling - - -UNIT_TESTING = False - - -def space_to_depth(x): - xs = x.size() - # Pick off every second element - x = x.view(xs[0], xs[1], xs[2] // 2, 2, xs[3] // 2, 2) - # Transpose picked elements next to channels. - x = x.permute((0, 1, 3, 5, 2, 4)).contiguous() - # Combine with channels. - x = x.view(xs[0], xs[1] * 4, xs[2] // 2, xs[3] // 2) - return x - - -def depth_to_space(x): - xs = x.size() - # Pick off elements from channels - x = x.view(xs[0], xs[1] // 4, 2, 2, xs[2], xs[3]) - # Transpose picked elements next to HW dimensions. - x = x.permute((0, 1, 4, 2, 5, 3)).contiguous() - # Combine with HW dimensions. - x = x.view(xs[0], xs[1] // 4, xs[2] * 2, xs[3] * 2) - return x - - -def int_shape(x): - return list(map(int, x.size())) - - -class Flatten(Base): - def forward(self, x): - return x.view(x.size(0), -1) - - -class Reshape(Base): - def __init__(self, shape): - super().__init__() - self.shape = shape - - def forward(self, x): - return x.view(x.size(0), *self.shape) - - -class Reverse(Base): - def __init__(self): - super().__init__() - - def forward(self, z, reverse=False): - flip_idx = torch.arange(z.size(1) - 1, -1, -1).long() - z = z[:, flip_idx, :, :] - return z - - -class Permute(Base): - def __init__(self, n_channels): - super().__init__() - - permutation = np.arange(n_channels, dtype='int') - np.random.shuffle(permutation) - - permutation_inv = np.zeros(n_channels, dtype='int') - permutation_inv[permutation] = np.arange(n_channels, dtype='int') - - self.permutation = torch.from_numpy(permutation) - self.permutation_inv = torch.from_numpy(permutation_inv) - - def forward(self, z, ldj, reverse=False): - if not reverse: - z = z[:, self.permutation, :, :] - else: - z = z[:, self.permutation_inv, :, :] - - return z, ldj - - def InversePermute(self): - inv_permute = Permute(len(self.permutation)) - inv_permute.permutation = self.permutation_inv - inv_permute.permutation_inv = self.permutation - return inv_permute - - -class Squeeze(Base): - def __init__(self): - super().__init__() - - def forward(self, z, ldj, reverse=False): - if not reverse: - z = space_to_depth(z) - else: - z = depth_to_space(z) - return z, ldj - - -class GenerativeFlow(Base): - def __init__(self, n_channels, height, width, args): - super().__init__() - layers = [] - layers.append(Squeeze()) - n_channels *= 4 - height //= 2 - width //= 2 - - for level in range(args.n_levels): - - for i in range(args.n_flows): - perm_layer = Permute(n_channels) - layers.append(perm_layer) - - layers.append( - Coupling(n_channels, height, width, args)) - - if level < args.n_levels - 1: - if args.splitprior_type != 'none': - # Standard splitprior - factor_out = n_channels // 2 - layers.append(SplitPrior(n_channels, factor_out, height, width, args)) - n_channels = n_channels - factor_out - - layers.append(Squeeze()) - n_channels *= 4 - height //= 2 - width //= 2 - - self.layers = torch.nn.ModuleList(layers) - self.z_size = (n_channels, height, width) - - def forward(self, z, ldj, pys=(), ys=(), reverse=False): - if not reverse: - for l, layer in enumerate(self.layers): - if isinstance(layer, (SplitPrior)): - py, y, z, ldj = layer(z, ldj) - pys += (py,) - ys += (y,) - - else: - z, ldj = layer(z, ldj) - - else: - for l, layer in reversed(list(enumerate(self.layers))): - if isinstance(layer, (SplitPrior)): - if len(ys) > 0: - z, ldj = layer.inverse(z, ldj, y=ys[-1]) - # Pop last element - ys = ys[:-1] - else: - z, ldj = layer.inverse(z, ldj, y=None) - - else: - z, ldj = layer(z, ldj, reverse=True) - - return z, ldj, pys, ys - - def decode(self, z, ldj, state, decode_fn): - - for l, layer in reversed(list(enumerate(self.layers))): - if isinstance(layer, SplitPrior): - z, ldj, state = layer.decode(z, ldj, state, decode_fn) - - else: - z, ldj = layer(z, ldj, reverse=True) - - return z, ldj diff --git a/integer_discrete_flows/models/networks.py b/integer_discrete_flows/models/networks.py deleted file mode 100644 index 480f4df..0000000 --- a/integer_discrete_flows/models/networks.py +++ /dev/null @@ -1,154 +0,0 @@ -""" -Collection of flow strategies -""" - -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F -from models.utils import Base - - -UNIT_TESTING = False - - -class Conv2dReLU(Base): - def __init__( - self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0, - bias=True): - super().__init__() - - self.nn = nn.Conv2d(n_inputs, n_outputs, kernel_size, padding=padding) - - def forward(self, x): - h = self.nn(x) - - y = F.relu(h) - - return y - - -class ResidualBlock(Base): - def __init__(self, n_channels, kernel, Conv2dAct): - super().__init__() - - self.nn = torch.nn.Sequential( - Conv2dAct(n_channels, n_channels, kernel, padding=1), - torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1), - ) - - def forward(self, x): - h = self.nn(x) - h = F.relu(h + x) - return h - - -class DenseLayer(Base): - def __init__(self, args, n_inputs, growth, Conv2dAct): - super().__init__() - - conv1x1 = Conv2dAct( - n_inputs, n_inputs, kernel_size=1, stride=1, - padding=0, bias=True) - - self.nn = torch.nn.Sequential( - conv1x1, - Conv2dAct( - n_inputs, growth, kernel_size=3, stride=1, - padding=1, bias=True), - ) - - def forward(self, x): - h = self.nn(x) - - h = torch.cat([x, h], dim=1) - return h - - -class DenseBlock(Base): - def __init__( - self, args, n_inputs, n_outputs, kernel, Conv2dAct): - super().__init__() - depth = args.densenet_depth - - future_growth = n_outputs - n_inputs - - layers = [] - - for d in range(depth): - growth = future_growth // (depth - d) - - layers.append(DenseLayer(args, n_inputs, growth, Conv2dAct)) - n_inputs += growth - future_growth -= growth - - self.nn = torch.nn.Sequential(*layers) - - def forward(self, x): - return self.nn(x) - - -class Identity(Base): - def __init__(self): - super.__init__() - - def forward(self, x): - return x - - -class NN(Base): - def __init__( - self, args, c_in, c_out, height, width, nn_type, kernel=3): - super().__init__() - - Conv2dAct = Conv2dReLU - n_channels = args.n_channels - - if nn_type == 'shallow': - - if args.network1x1 == 'standard': - conv1x1 = Conv2dAct( - n_channels, n_channels, kernel_size=1, stride=1, - padding=0, bias=False) - - layers = [ - Conv2dAct(c_in, n_channels, kernel, padding=1), - conv1x1] - - layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)] - - elif nn_type == 'resnet': - layers = [ - Conv2dAct(c_in, n_channels, kernel, padding=1), - ResidualBlock(n_channels, kernel, Conv2dAct), - ResidualBlock(n_channels, kernel, Conv2dAct)] - - layers += [ - torch.nn.Conv2d(n_channels, c_out, kernel, padding=1) - ] - - elif nn_type == 'densenet': - layers = [ - DenseBlock( - args=args, - n_inputs=c_in, - n_outputs=n_channels + c_in, - kernel=kernel, - Conv2dAct=Conv2dAct)] - - layers += [ - torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1) - ] - else: - raise ValueError - - self.nn = torch.nn.Sequential(*layers) - - # Set parameters of last conv-layer to zero. - if not UNIT_TESTING: - self.nn[-1].weight.data.zero_() - self.nn[-1].bias.data.zero_() - - def forward(self, x): - return self.nn(x) diff --git a/integer_discrete_flows/models/priors.py b/integer_discrete_flows/models/priors.py deleted file mode 100644 index 5ace8e9..0000000 --- a/integer_discrete_flows/models/priors.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Collection of flow strategies -""" - -from __future__ import print_function - -import torch -import torch.nn.functional as F -from torch.nn import Parameter -from utils.distributions import sample_discretized_logistic, \ - sample_mixture_discretized_logistic, sample_normal, sample_logistic, \ - sample_discretized_normal, sample_mixture_normal -from models.utils import Base -from .networks import NN - - -def sample_prior(px, variable_type, distribution_type, inverse_bin_width): - if variable_type == 'discrete': - if distribution_type == 'logistic': - if len(px) == 2: - return sample_discretized_logistic( - *px, inverse_bin_width=inverse_bin_width) - elif len(px) == 3: - return sample_mixture_discretized_logistic( - *px, inverse_bin_width=inverse_bin_width) - - elif distribution_type == 'normal': - return sample_discretized_normal( - *px, inverse_bin_width=inverse_bin_width) - - elif variable_type == 'continuous': - if distribution_type == 'logistic': - return sample_logistic(*px) - elif distribution_type == 'normal': - if len(px) == 2: - return sample_normal(*px) - elif len(px) == 3: - return sample_mixture_normal(*px) - elif distribution_type == 'steplogistic': - return sample_logistic(*px) - - raise ValueError - - -class Prior(Base): - def __init__(self, size, args): - super().__init__() - c, h, w = size - - self.inverse_bin_width = 2**args.n_bits - self.variable_type = args.variable_type - self.distribution_type = args.distribution_type - self.n_mixtures = args.n_mixtures - - if self.n_mixtures == 1: - self.mu = Parameter(torch.Tensor(c, h, w)) - self.logs = Parameter(torch.Tensor(c, h, w)) - elif self.n_mixtures > 1: - self.mu = Parameter(torch.Tensor(c, h, w, self.n_mixtures)) - self.logs = Parameter(torch.Tensor(c, h, w, self.n_mixtures)) - self.pi_logit = Parameter(torch.Tensor(c, h, w, self.n_mixtures)) - - self.reset_parameters() - - def reset_parameters(self): - self.mu.data.zero_() - - if self.n_mixtures > 1: - self.pi_logit.data.zero_() - for i in range(self.n_mixtures): - self.mu.data[..., i] += i - (self.n_mixtures - 1) / 2. - - self.logs.data.zero_() - - def get_pz(self, n): - if self.n_mixtures == 1: - mu = self.mu.repeat(n, 1, 1, 1) - logs = self.logs.repeat(n, 1, 1, 1) # scaling scale - return mu, logs - - elif self.n_mixtures > 1: - pi = F.softmax(self.pi_logit, dim=-1) - mu = self.mu.repeat(n, 1, 1, 1, 1) - logs = self.logs.repeat(n, 1, 1, 1, 1) - pi = pi.repeat(n, 1, 1, 1, 1) - return mu, logs, pi - - def forward(self, z, ldj): - pz = self.get_pz(z.size(0)) - - return pz, z, ldj - - def sample(self, n): - pz = self.get_pz(n) - - z_sample = sample_prior(pz, self.variable_type, self.distribution_type, self.inverse_bin_width) - - return z_sample - - def decode(self, states, decode_fn): - pz = self.get_pz(n=len(states)) - - states, z = decode_fn(states, pz) - return states, z - - -class SplitPrior(Base): - def __init__(self, c_in, factor_out, height, width, args): - super().__init__() - - self.split_idx = c_in - factor_out - self.inverse_bin_width = 2**args.n_bits - self.variable_type = args.variable_type - self.distribution_type = args.distribution_type - self.input_channel = c_in - - self.nn = NN( - args=args, - c_in=c_in - factor_out, - c_out=factor_out * 2, - height=height, - width=width, - nn_type=args.splitprior_type) - - def get_py(self, z): - h = self.nn(z) - mu = h[:, ::2, :, :] - logs = h[:, 1::2, :, :] - - py = [mu, logs] - - return py - - def split(self, z): - z1 = z[:, :self.split_idx, :, :] - y = z[:, self.split_idx:, :, :] - return z1, y - - def combine(self, z, y): - result = torch.cat([z, y], dim=1) - - return result - - def forward(self, z, ldj): - z, y = self.split(z) - - py = self.get_py(z) - - return py, y, z, ldj - - def inverse(self, z, ldj, y): - # Sample if y is not given. - if y is None: - py = self.get_py(z) - y = sample_prior(py, self.variable_type, self.distribution_type, self.inverse_bin_width) - - z = self.combine(z, y) - - return z, ldj - - def decode(self, z, ldj, states, decode_fn): - py = self.get_py(z) - states, y = decode_fn(states, py) - return self.combine(z, y), ldj, states diff --git a/integer_discrete_flows/models/utils.py b/integer_discrete_flows/models/utils.py deleted file mode 100644 index 958f760..0000000 --- a/integer_discrete_flows/models/utils.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch - - -class Base(torch.nn.Module): - """ - The base class for modules. That contains a disable round mode - """ - - def __init__(self): - super().__init__() - - def _set_child_attribute(self, attr, value): - r"""Sets the module in rounding mode. - - This has any effect only on certain modules if variable type is - discrete. - - Returns: - Module: self - """ - if hasattr(self, attr): - setattr(self, attr, value) - - for module in self.modules(): - if hasattr(module, attr): - setattr(module, attr, value) - return self - - def set_temperature(self, value): - self._set_child_attribute("temperature", value) - - def enable_hard_round(self, mode=True): - self._set_child_attribute("hard_round", mode) - - def disable_hard_round(self, mode=True): - self.enable_hard_round(not mode) diff --git a/integer_discrete_flows/optimization/__init__.py b/integer_discrete_flows/optimization/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/integer_discrete_flows/optimization/loss.py b/integer_discrete_flows/optimization/loss.py deleted file mode 100644 index 6cdd704..0000000 --- a/integer_discrete_flows/optimization/loss.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import print_function - -import numpy as np -import torch -from utils.distributions import log_discretized_logistic, \ - log_mixture_discretized_logistic, log_normal, log_discretized_normal, \ - log_logistic, log_mixture_normal -from models.backround import _round_straightthrough - - -def compute_log_ps(pxs, xs, args): - # Add likelihoods of intermediate representations. - inverse_bin_width = 2.**args.n_bits - - log_pxs = [] - for px, x in zip(pxs, xs): - - if args.variable_type == 'discrete': - if args.distribution_type == 'logistic': - log_px = log_discretized_logistic( - x, *px, inverse_bin_width=inverse_bin_width) - elif args.distribution_type == 'normal': - log_px = log_discretized_normal( - x, *px, inverse_bin_width=inverse_bin_width) - elif args.variable_type == 'continuous': - if args.distribution_type == 'logistic': - log_px = log_logistic(x, *px) - elif args.distribution_type == 'normal': - log_px = log_normal(x, *px) - elif args.distribution_type == 'steplogistic': - x = _round_straightthrough(x * inverse_bin_width) / inverse_bin_width - log_px = log_discretized_logistic( - x, *px, inverse_bin_width=inverse_bin_width) - - log_pxs.append( - torch.sum(log_px, dim=[1, 2, 3])) - - return log_pxs - - -def compute_log_pz(pz, z, args): - inverse_bin_width = 2.**args.n_bits - - if args.variable_type == 'discrete': - if args.distribution_type == 'logistic': - if args.n_mixtures == 1: - log_pz = log_discretized_logistic( - z, pz[0], pz[1], inverse_bin_width=inverse_bin_width) - else: - log_pz = log_mixture_discretized_logistic( - z, pz[0], pz[1], pz[2], - inverse_bin_width=inverse_bin_width) - elif args.distribution_type == 'normal': - log_pz = log_discretized_normal( - z, *pz, inverse_bin_width=inverse_bin_width) - - elif args.variable_type == 'continuous': - if args.distribution_type == 'logistic': - log_pz = log_logistic(z, *pz) - elif args.distribution_type == 'normal': - if args.n_mixtures == 1: - log_pz = log_normal(z, *pz) - else: - log_pz = log_mixture_normal(z, *pz) - elif args.distribution_type == 'steplogistic': - z = _round_straightthrough(z * 256.) / 256. - log_pz = log_discretized_logistic(z, *pz) - - log_pz = torch.sum( - log_pz, - dim=[1, 2, 3]) - - return log_pz - - -def compute_loss_function(pz, z, pys, ys, ldj, args): - """ - Computes the cross entropy loss function while summing over batch dimension, not averaged! - :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits - :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1]. - :param z_mu: mean of z_0 - :param z_var: variance of z_0 - :param z_0: first stochastic latent variable - :param z_k: last stochastic latent variable - :param ldj: log det jacobian - :param args: global parameter settings - :param beta: beta for kl loss - :return: loss, ce, kl - """ - batch_size = z.size(0) - - # Get array loss, sum over batch - loss_array, bpd_array, bpd_per_prior_array = \ - compute_loss_array(pz, z, pys, ys, ldj, args) - - loss = torch.mean(loss_array) - bpd = torch.mean(bpd_array).item() - bpd_per_prior = [torch.mean(x) for x in bpd_per_prior_array] - - return loss, bpd, bpd_per_prior - - -def convert_bpd(log_p, input_size): - return -log_p / (np.prod(input_size) * np.log(2.)) - - -def compute_loss_array(pz, z, pys, ys, ldj, args): - """ - Computes the cross entropy loss function while summing over batch dimension, not averaged! - :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits - :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1]. - :param z_mu: mean of z_0 - :param z_var: variance of z_0 - :param z_0: first stochastic latent variable - :param z_k: last stochastic latent variable - :param ldj: log det jacobian - :param args: global parameter settings - :param beta: beta for kl loss - :return: loss, ce, kl - """ - bpd_per_prior = [] - - # Likelihood of final representation. - log_pz = compute_log_pz(pz, z, args) - - bpd_per_prior.append(convert_bpd(log_pz.detach(), args.input_size)) - - log_p = log_pz - - # Add likelihoods of intermediate representations. - if ys: - log_pys = compute_log_ps(pys, ys, args) - - for log_py in log_pys: - log_p += log_py - - bpd_per_prior.append(convert_bpd(log_py.detach(), args.input_size)) - - log_p += ldj - - loss = -log_p - bpd = convert_bpd(log_p.detach(), args.input_size) - - return loss, bpd, bpd_per_prior - - -def calculate_loss(pz, z, pys, ys, ldj, loss_aux, args): - return compute_loss_function(pz, z, pys, ys, ldj, loss_aux, args) diff --git a/integer_discrete_flows/optimization/training.py b/integer_discrete_flows/optimization/training.py deleted file mode 100644 index 881d5c5..0000000 --- a/integer_discrete_flows/optimization/training.py +++ /dev/null @@ -1,174 +0,0 @@ -from __future__ import print_function -import torch - -from optimization.loss import calculate_loss -from utils.visual_evaluation import plot_reconstructions - -import numpy as np - - -def train(epoch, train_loader, model, opt, args): - model.train() - train_loss = np.zeros(len(train_loader)) - train_bpd = np.zeros(len(train_loader)) - - num_data = 0 - - for batch_idx, (data, _) in enumerate(train_loader): - data = data.view(-1, *args.input_size) - - data = data.to(args.DEVICE) - - opt.zero_grad() - loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data) - - loss = torch.mean(loss) - bpd = torch.mean(bpd) - bpd_per_prior = [torch.mean(i) for i in bpd_per_prior] - - loss.backward() - loss = loss.item() - train_loss[batch_idx] = loss - train_bpd[batch_idx] = bpd - - ldj = torch.mean(ldj).item() / np.prod(args.input_size) / np.log(2) - - opt.step() - - num_data += len(data) - - if batch_idx % args.log_interval == 0: - perc = 100. * batch_idx / len(train_loader) - - tmp = 'Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] \tLoss: {:11.6f}\tbpd: {:8.6f}\tbits ldj: {:8.6f}' - print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, loss, bpd, ldj)) - - print('z min: {:8.3f}, max: {:8.3f}'.format(torch.min(z).item() * 256, torch.max(z).item() * 256)) - - print('z bpd: {:.3f}'.format(bpd_per_prior[0])) - for i in range(1, len(bpd_per_prior)): - print('y{} bpd: {:.3f}'.format(i-1, bpd_per_prior[i])) - - print('pz mu', np.mean(pz[0].data.cpu().numpy(), axis=(0, 1, 2, 3))) - print('pz logs ', np.mean(pz[1].data.cpu().numpy(), axis=(0, 1, 2, 3))) - if len(pz) == 3: - print('pz pi ', np.mean(pz[2].data.cpu().numpy(), axis=(0, 1, 2, 3))) - - for i, py in enumerate(pys): - print('py{} mu '.format(i), np.mean(py[0].data.cpu().numpy(), axis=(0, 1, 2, 3))) - print('py{} logs '.format(i), np.mean(py[1].data.cpu().numpy(), axis=(0, 1, 2, 3))) - - from utils.visual_evaluation import plot_images - import os - if not os.path.exists(args.snap_dir + 'training/'): - os.makedirs(args.snap_dir + 'training/') - - print('====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'.format( - epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader))) - - return train_loss, train_bpd - - -def evaluate(train_loader, val_loader, model, model_sample, args, testing=False, file=None, epoch=0): - model.eval() - - loss_type = 'bpd' - - def analyse(data_loader, plot=False): - bpds = [] - batch_idx = 0 - with torch.no_grad(): - for data, _ in data_loader: - batch_idx += 1 - - if args.cuda: - data = data.cuda() - - data = data.view(-1, *args.input_size) - - loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj = \ - model(data) - loss = torch.mean(loss).item() - batch_bpd = torch.mean(batch_bpd).item() - - bpds.append(batch_bpd) - - bpd = np.mean(bpds) - - with torch.no_grad(): - if not testing and plot: - x_sample = model_sample.sample(n=100) - - try: - plot_reconstructions( - x_sample, bpd, loss_type, epoch, args) - except: - print('Not plotting') - - return bpd - - bpd_train = analyse(train_loader) - bpd_val = analyse(val_loader, plot=True) - - with open(file, 'a') as ff: - msg = 'epoch {}\ttrain bpd {:.3f}\tval bpd {:.3f}\t'.format( - epoch, - bpd_train, - bpd_val) - print(msg, file=ff) - - loss = bpd_val * np.prod(args.input_size) * np.log(2.) - bpd = bpd_val - - file = None - - # Compute log-likelihood - with torch.no_grad(): - if testing: - test_data = val_loader.dataset.data_tensor - - if args.cuda: - test_data = test_data.cuda() - - print('Computing log-likelihood on test set') - - model.eval() - - log_likelihood = analyse(test_data) - - else: - log_likelihood = None - nll_bpd = None - - if file is None: - if testing: - print('====> Test set loss: {:.4f}'.format(loss)) - print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood)) - - print('====> Test set bpd (elbo): {:.4f}'.format(bpd)) - print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood/ - (np.prod(args.input_size) * np.log(2.)))) - - else: - print('====> Validation set loss: {:.4f}'.format(loss)) - print('====> Validation set bpd: {:.4f}'.format(bpd)) - else: - with open(file, 'a') as ff: - if testing: - print('====> Test set loss: {:.4f}'.format(loss), file=ff) - print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood), file=ff) - - print('====> Test set bpd: {:.4f}'.format(bpd), file=ff) - print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood / - (np.prod(args.input_size) * np.log(2.))), - file=ff) - - else: - print('====> Validation set loss: {:.4f}'.format(loss), file=ff) - print('====> Validation set bpd: {:.4f}'.format(loss / (np.prod(args.input_size) * np.log(2.))), - file=ff) - - if not testing: - return loss, bpd - else: - return log_likelihood, nll_bpd diff --git a/integer_discrete_flows/utils/__init__.py b/integer_discrete_flows/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/integer_discrete_flows/utils/distributions.py b/integer_discrete_flows/utils/distributions.py deleted file mode 100644 index 7cc2381..0000000 --- a/integer_discrete_flows/utils/distributions.py +++ /dev/null @@ -1,209 +0,0 @@ -from __future__ import print_function -import torch -import torch.utils.data -import torch.nn.functional as F - -import numpy as np -import math - -MIN_EPSILON = 1e-5 -MAX_EPSILON = 1.-1e-5 - - -PI = math.pi - - -def log_min_exp(a, b, epsilon=1e-8): - """ - Computes the log of exp(a) - exp(b) in a (more) numerically stable fashion. - Using: - log(exp(a) - exp(b)) - c + log(exp(a-c) - exp(b-c)) - a + log(1 - exp(b-a)) - And note that we assume b < a always. - """ - y = a + torch.log(1 - torch.exp(b - a) + epsilon) - - return y - - -def log_normal(x, mean, logvar): - logp = -0.5 * logvar - logp += -0.5 * np.log(2 * PI) - logp += -0.5 * (x - mean) * (x - mean) / torch.exp(logvar) - return logp - - -def log_mixture_normal(x, mean, logvar, pi): - x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1) - - logp_mixtures = log_normal(x, mean, logvar) - - logp = torch.log(torch.sum(pi * torch.exp(logp_mixtures), dim=-1) + 1e-8) - - return logp - - -def sample_normal(mean, logvar): - y = torch.randn_like(mean) - - x = torch.exp(0.5 * logvar) * y + mean - - return x - - -def sample_mixture_normal(mean, logvar, pi): - b, c, h, w, n_mixtures = tuple(map(int, pi.size())) - pi = pi.view(b * c * h * w, n_mixtures) - sampled_pi = torch.multinomial(pi, num_samples=1).view(-1) - - # Select mixture params - mean = mean.view(b * c * h * w, n_mixtures) - mean = mean[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w) - logvar = logvar.view(b * c * h * w, n_mixtures) - logvar = logvar[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w) - - y = sample_normal(mean, logvar) - - return y - - -def log_logistic(x, mean, logscale): - """ - pdf = sigma([x - mean] / scale) * [1 - sigma(...)] * 1/scale - """ - scale = torch.exp(logscale) - - u = (x - mean) / scale - - logp = F.logsigmoid(u) + F.logsigmoid(-u) - logscale - - return logp - - -def sample_logistic(mean, logscale): - y = torch.rand_like(mean) - - x = torch.exp(logscale) * torch.log(y / (1 - y)) + mean - - return x - - -def log_discretized_logistic(x, mean, logscale, inverse_bin_width): - scale = torch.exp(logscale) - - logp = log_min_exp( - F.logsigmoid((x + 0.5 / inverse_bin_width - mean) / scale), - F.logsigmoid((x - 0.5 / inverse_bin_width - mean) / scale)) - - return logp - - -def discretized_logistic_cdf(x, mean, logscale, inverse_bin_width): - scale = torch.exp(logscale) - - cdf = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale) - - return cdf - - -def sample_discretized_logistic(mean, logscale, inverse_bin_width): - x = sample_logistic(mean, logscale) - - x = torch.round(x * inverse_bin_width) / inverse_bin_width - return x - - -def normal_cdf(value, loc, std): - return 0.5 * (1 + torch.erf((value - loc) * std.reciprocal() / math.sqrt(2))) - - -def log_discretized_normal(x, mean, logvar, inverse_bin_width): - std = torch.exp(0.5 * logvar) - log_p = torch.log(normal_cdf(x + 0.5 / inverse_bin_width, mean, std) - normal_cdf(x - 0.5 / inverse_bin_width, mean, std) + 1e-7) - - return log_p - - -def log_mixture_discretized_normal(x, mean, logvar, pi, inverse_bin_width): - std = torch.exp(0.5 * logvar) - - x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1) - - p = normal_cdf(x + 0.5 / inverse_bin_width, mean, std) - normal_cdf(x - 0.5 / inverse_bin_width, mean, std) - - p = torch.sum(p * pi, dim=-1) - - logp = torch.log(p + 1e-8) - - return logp - - -def sample_discretized_normal(mean, logvar, inverse_bin_width): - y = torch.randn_like(mean) - - x = torch.exp(0.5 * logvar) * y + mean - - x = torch.round(x * inverse_bin_width) / inverse_bin_width - - return x - - -def log_mixture_discretized_logistic(x, mean, logscale, pi, inverse_bin_width): - scale = torch.exp(logscale) - - x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1) - - p = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale) \ - - torch.sigmoid((x - 0.5 / inverse_bin_width - mean) / scale) - - p = torch.sum(p * pi, dim=-1) - - logp = torch.log(p + 1e-8) - - return logp - - -def mixture_discretized_logistic_cdf(x, mean, logscale, pi, inverse_bin_width): - scale = torch.exp(logscale) - - x = x[..., None] - - cdfs = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale) - - cdf = torch.sum(cdfs * pi, dim=-1) - - return cdf - - -def sample_mixture_discretized_logistic(mean, logs, pi, inverse_bin_width): - # Sample mixtures - b, c, h, w, n_mixtures = tuple(map(int, pi.size())) - pi = pi.view(b * c * h * w, n_mixtures) - sampled_pi = torch.multinomial(pi, num_samples=1).view(-1) - - # Select mixture params - mean = mean.view(b * c * h * w, n_mixtures) - mean = mean[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w) - logs = logs.view(b * c * h * w, n_mixtures) - logs = logs[torch.arange(b*c*h*w), sampled_pi].view(b, c, h, w) - - y = torch.rand_like(mean) - x = torch.exp(logs) * torch.log(y / (1 - y)) + mean - - x = torch.round(x * inverse_bin_width) / inverse_bin_width - - return x - - -def log_multinomial(logits, targets): - return -F.cross_entropy(logits, targets, reduction='none') - - -def sample_multinomial(logits): - b, n_categories, c, h, w = logits.size() - logits = logits.permute(0, 2, 3, 4, 1) - p = F.softmax(logits, dim=-1) - p = p.view(b * c * h * w, n_categories) - x = torch.multinomial(p, num_samples=1).view(b, c, h, w) - return x diff --git a/integer_discrete_flows/utils/load_data.py b/integer_discrete_flows/utils/load_data.py deleted file mode 100644 index b743e69..0000000 --- a/integer_discrete_flows/utils/load_data.py +++ /dev/null @@ -1,264 +0,0 @@ -from __future__ import print_function - -import numbers - -import torch -import torch.utils.data as data_utils -import pickle -from scipy.io import loadmat - -import numpy as np - -import os -from torch.utils.data import Dataset -import torchvision -from torchvision import transforms -from torchvision.transforms import functional as vf -from torch.utils.data import ConcatDataset - -from PIL import Image - -import os -import os.path -from os.path import join -import sys -import tarfile - - -class ToTensorNoNorm(): - def __call__(self, X_i): - return torch.from_numpy(np.array(X_i, copy=False)).permute(2, 0, 1) - - -class PadToMultiple(object): - def __init__(self, multiple, fill=0, padding_mode='constant'): - assert isinstance(multiple, numbers.Number) - assert isinstance(fill, (numbers.Number, str, tuple)) - assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] - - self.multiple = multiple - self.fill = fill - self.padding_mode = padding_mode - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be padded. - Returns: - PIL Image: Padded image. - """ - w, h = img.size - m = self.multiple - nw = (w // m + int((w % m) != 0)) * m - nh = (h // m + int((h % m) != 0)) * m - padw = nw - w - padh = nh - h - - out = vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode) - return out - - def __repr__(self): - return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\ - format(self.mulitple, self.fill, self.padding_mode) - - -class CustomTensorDataset(Dataset): - """Dataset wrapping tensors. - - Each sample will be retrieved by indexing tensors along the first dimension. - - Arguments: - *tensors (Tensor): tensors that have the same size of the first dimension. - """ - - def __init__(self, *tensors, transform=None): - assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) - self.tensors = tensors - self.transform = transform - - def __getitem__(self, index): - from PIL import Image - - X, y = self.tensors - X_i, y_i, = X[index], y[index] - - if self.transform: - X_i = self.transform(X_i) - X_i = torch.from_numpy(np.array(X_i, copy=False)) - X_i = X_i.permute(2, 0, 1) - - return X_i, y_i - - def __len__(self): - return self.tensors[0].size(0) - - -def load_cifar10(args, **kwargs): - # set args - args.input_size = [3, 32, 32] - args.input_type = 'continuous' - args.dynamic_binarization = False - - from keras.datasets import cifar10 - (x_train, y_train), (x_test, y_test) = cifar10.load_data() - - x_train = x_train.transpose(0, 3, 1, 2) - x_test = x_test.transpose(0, 3, 1, 2) - - import math - - if args.data_augmentation_level == 2: - data_transform = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomHorizontalFlip(), - transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'), - transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)), - transforms.CenterCrop(32) - ]) - elif args.data_augmentation_level == 1: - data_transform = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomHorizontalFlip(), - ]) - else: - data_transform = transforms.Compose([ - transforms.ToPILImage(), - ]) - - x_val = x_train[-10000:] - y_val = y_train[-10000:] - - x_train = x_train[:-10000] - y_train = y_train[:-10000] - - train = CustomTensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train), transform=data_transform) - train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs) - - validation = data_utils.TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val)) - val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs) - - test = data_utils.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test)) - test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs) - - return train_loader, val_loader, test_loader, args - - -def extract_tar(tarpath): - assert tarpath.endswith('.tar') - - startdir = tarpath[:-4] + '/' - - if os.path.exists(startdir): - return startdir - - print('Extracting', tarpath) - - with tarfile.open(name=tarpath) as tar: - t = 0 - done = False - while not done: - path = join(startdir, 'images{}'.format(t)) - os.makedirs(path, exist_ok=True) - - print(path) - - for i in range(50000): - member = tar.next() - - if member is None: - done = True - break - - # Skip directories - while member.isdir(): - member = tar.next() - if member is None: - done = True - break - - member.name = member.name.split('/')[-1] - - tar.extract(member, path=path) - - t += 1 - - return startdir - - -def load_imagenet(resolution, args, **kwargs): - assert resolution == 32 or resolution == 64 - - args.input_size = [3, resolution, resolution] - - trainpath = '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution) - valpath = '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution) - - trainpath = extract_tar(trainpath) - valpath = extract_tar(valpath) - - data_transform = transforms.Compose([ - ToTensorNoNorm() - ]) - - print('Starting loading ImageNet') - - imagenet_data = torchvision.datasets.ImageFolder( - trainpath, - transform=data_transform) - - print('Number of data images', len(imagenet_data)) - - val_idcs = np.random.choice(len(imagenet_data), size=20000, replace=False) - train_idcs = np.setdiff1d(np.arange(len(imagenet_data)), val_idcs) - - train_dataset = torch.utils.data.dataset.Subset( - imagenet_data, train_idcs) - val_dataset = torch.utils.data.dataset.Subset( - imagenet_data, val_idcs) - - train_loader = torch.utils.data.DataLoader( - train_dataset, - batch_size=args.batch_size, - shuffle=True, - **kwargs) - - val_loader = torch.utils.data.DataLoader( - val_dataset, - batch_size=args.batch_size, - shuffle=False, - **kwargs) - - test_dataset = torchvision.datasets.ImageFolder( - valpath, - transform=data_transform) - - print('Number of val images:', len(test_dataset)) - - test_loader = torch.utils.data.DataLoader( - test_dataset, - batch_size=args.batch_size, - shuffle=False, - **kwargs) - - return train_loader, val_loader, test_loader, args - - -def load_dataset(args, **kwargs): - - if args.dataset == 'cifar10': - train_loader, val_loader, test_loader, args = load_cifar10(args, **kwargs) - elif args.dataset == 'imagenet32': - train_loader, val_loader, test_loader, args = load_imagenet(32, args, **kwargs) - elif args.dataset == 'imagenet64': - train_loader, val_loader, test_loader, args = load_imagenet(64, args, **kwargs) - else: - raise Exception('Wrong name of the dataset!') - - return train_loader, val_loader, test_loader, args - - -if __name__ == '__main__': - class Args(): - def __init__(self): - self.batch_size = 128 - train_loader, val_loader, test_loader, args = load_imagenet32(Args()) diff --git a/integer_discrete_flows/utils/log_likelihood.py b/integer_discrete_flows/utils/log_likelihood.py deleted file mode 100644 index 4f47e7c..0000000 --- a/integer_discrete_flows/utils/log_likelihood.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import print_function -import numpy as np -from scipy.misc import logsumexp -from optimization.loss import calculate_loss_array - - -def calculate_likelihood(X, model, args, S=5000, MB=500): - - # set auxiliary variables for number of training and test sets - N_test = X.size(0) - - X = X.view(-1, *args.input_size) - - likelihood_test = [] - - if S <= MB: - R = 1 - else: - R = S // MB - S = MB - - for j in range(N_test): - if j % 100 == 0: - print('Progress: {:.2f}%'.format(j / (1. * N_test) * 100)) - - x_single = X[j].unsqueeze(0) - - a = [] - for r in range(0, R): - # Repeat it for all training points - x = x_single.expand(S, *x_single.size()[1:]).contiguous() - - x_mean, z_mu, z_var, ldj, z0, zk = model(x) - - a_tmp = calculate_loss_array(x_mean, x, z_mu, z_var, z0, zk, ldj, args) - - a.append(-a_tmp.cpu().data.numpy()) - - # calculate max - a = np.asarray(a) - a = np.reshape(a, (a.shape[0] * a.shape[1], 1)) - likelihood_x = logsumexp(a) - likelihood_test.append(likelihood_x - np.log(len(a))) - - likelihood_test = np.array(likelihood_test) - - nll = -np.mean(likelihood_test) - - if args.input_type == 'multinomial': - bpd = nll/(np.prod(args.input_size) * np.log(2.)) - elif args.input_type == 'binary': - bpd = 0. - else: - raise ValueError('invalid input type!') - - return nll, bpd diff --git a/integer_discrete_flows/utils/plotting.py b/integer_discrete_flows/utils/plotting.py deleted file mode 100644 index 6259834..0000000 --- a/integer_discrete_flows/utils/plotting.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import division -from __future__ import print_function - -import numpy as np -import matplotlib -# noninteractive background -matplotlib.use('Agg') -import matplotlib.pyplot as plt - - -def plot_training_curve(train_loss, validation_loss, fname='training_curve.pdf', labels=None): - """ - Plots train_loss and validation loss as a function of optimization iteration - :param train_loss: np.array of train_loss (1D or 2D) - :param validation_loss: np.array of validation loss (1D or 2D) - :param fname: output file name - :param labels: if train_loss and validation loss are 2D, then labels indicate which variable is varied - accross training curves. - :return: None - """ - - plt.close() - - matplotlib.rcParams.update({'font.size': 14}) - matplotlib.rcParams['mathtext.fontset'] = 'stix' - matplotlib.rcParams['font.family'] = 'STIXGeneral' - - if len(train_loss.shape) == 1: - # Single training curve - fig, ax = plt.subplots(nrows=1, ncols=1) - figsize = (6, 4) - - if train_loss.shape[0] == validation_loss.shape[0]: - # validation score evaluated every iteration - x = np.arange(train_loss.shape[0]) - ax.plot(x, train_loss, '-', lw=2., color='black', label='train') - ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val') - - elif train_loss.shape[0] % validation_loss.shape[0] == 0: - # validation score evaluated every epoch - x = np.arange(train_loss.shape[0]) - ax.plot(x, train_loss, '-', lw=2., color='black', label='train') - - x = np.arange(validation_loss.shape[0]) - x = (x + 1) * train_loss.shape[0] / validation_loss.shape[0] - ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val') - else: - raise ValueError('Length of train_loss and validation_loss must be equal or divisible') - - miny = np.minimum(validation_loss.min(), train_loss.min()) - 20. - maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30. - ax.set_ylim([miny, maxy]) - - elif len(train_loss.shape) == 2: - # Multiple training curves - - cmap = plt.cm.brg - - cNorm = matplotlib.colors.Normalize(vmin=0, vmax=train_loss.shape[0]) - scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap) - - fig, ax = plt.subplots(nrows=1, ncols=1) - figsize = (6, 4) - - if labels is None: - labels = ['%d' % i for i in range(train_loss.shape[0])] - - if train_loss.shape[1] == validation_loss.shape[1]: - for i in range(train_loss.shape[0]): - color_val = scalarMap.to_rgba(i) - - # validation score evaluated every iteration - x = np.arange(train_loss.shape[0]) - ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i]) - ax.plot(x, validation_loss[i], '--', lw=2., color=color_val) - - elif train_loss.shape[1] % validation_loss.shape[1] == 0: - for i in range(train_loss.shape[0]): - color_val = scalarMap.to_rgba(i) - - # validation score evaluated every epoch - x = np.arange(train_loss.shape[1]) - ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i]) - - x = np.arange(validation_loss.shape[1]) - x = (x+1) * train_loss.shape[1] / validation_loss.shape[1] - ax.plot(x, validation_loss[i], '-', lw=2., color=color_val) - - miny = np.minimum(validation_loss.min(), train_loss.min()) - 20. - maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30. - ax.set_ylim([miny, maxy]) - - else: - raise ValueError('train_loss and validation_loss must be 1D or 2D arrays') - - ax.set_xlabel('iteration') - ax.set_ylabel('loss') - plt.title('Training and validation loss') - - fig.set_size_inches(figsize) - fig.subplots_adjust(hspace=0.1) - plt.savefig(fname, bbox_inches='tight') - - plt.close() diff --git a/integer_discrete_flows/utils/visual_evaluation.py b/integer_discrete_flows/utils/visual_evaluation.py deleted file mode 100644 index 340fb69..0000000 --- a/integer_discrete_flows/utils/visual_evaluation.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import print_function -import os -import numpy as np -import imageio - - -def plot_reconstructions(recon_mean, loss, loss_type, epoch, args): - if epoch == 1: - if not os.path.exists(args.snap_dir + 'reconstruction/'): - os.makedirs(args.snap_dir + 'reconstruction/') - if loss_type == 'bpd': - fname = str(epoch) + '_bpd_%5.3f' % loss - elif loss_type == 'elbo': - fname = str(epoch) + '_elbo_%6.4f' % loss - plot_images(args, recon_mean.data.cpu().numpy()[:100], args.snap_dir + 'reconstruction/', fname) - - -def plot_images(args, x_sample, dir, file_name, size_x=10, size_y=10): - batch, channels, height, width = x_sample.shape - - print(x_sample.shape) - - mosaic = np.zeros((height * size_y, width * size_x, channels)) - - for j in range(size_y): - for i in range(size_x): - idx = j * size_x + i - - image = x_sample[idx] - - mosaic[j*height:(j+1)*height, i*height:(i+1)*height] = \ - image.transpose(1, 2, 0) - - # Remove channel for BW images - mosaic = mosaic.squeeze() - - imageio.imwrite(dir + file_name + '.png', mosaic) diff --git a/CNN-model/main_cnn.py b/main.py similarity index 98% rename from CNN-model/main_cnn.py rename to main.py index 1789572..d43cab5 100644 --- a/CNN-model/main_cnn.py +++ b/main.py @@ -89,7 +89,7 @@ def main(): model = None if args.model_path is not None: - print("Loading the model...") + print("Loading the models...") model = torch.load(args.model_path) trainer: Trainer = OptunaTrainer(n_trials=3 if args.debug else None) if args.method == "optuna" else FullTrainer() diff --git a/CNN-model/model/__init__.py b/models/__init__.py similarity index 100% rename from CNN-model/model/__init__.py rename to models/__init__.py diff --git a/models/cnn/__init__.py b/models/cnn/__init__.py new file mode 100644 index 0000000..153551d --- /dev/null +++ b/models/cnn/__init__.py @@ -0,0 +1 @@ +from .cnn import CNNPredictor \ No newline at end of file diff --git a/CNN-model/model/cnn.py b/models/cnn/cnn.py similarity index 100% rename from CNN-model/model/cnn.py rename to models/cnn/cnn.py diff --git a/CNN-model/models/final_model.pt b/saved_models/final_model.pt similarity index 100% rename from CNN-model/models/final_model.pt rename to saved_models/final_model.pt diff --git a/CNN-model/trainers/FullTrainer.py b/trainers/FullTrainer.py similarity index 95% rename from CNN-model/trainers/FullTrainer.py rename to trainers/FullTrainer.py index fecfe90..7f7882a 100644 --- a/CNN-model/trainers/FullTrainer.py +++ b/trainers/FullTrainer.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from .trainer import Trainer from .train import train -from utils import print_losses +from ..utils import print_losses class FullTrainer(Trainer): def execute( diff --git a/CNN-model/trainers/OptunaTrainer.py b/trainers/OptunaTrainer.py similarity index 94% rename from CNN-model/trainers/OptunaTrainer.py rename to trainers/OptunaTrainer.py index 6f0b3b9..eb896fc 100644 --- a/CNN-model/trainers/OptunaTrainer.py +++ b/trainers/OptunaTrainer.py @@ -7,7 +7,7 @@ from torch import nn as nn from torch.utils.data import DataLoader from .trainer import Trainer -from model.cnn import CNNPredictor +from ..models.cnn import CNNPredictor from .train import train @@ -59,4 +59,4 @@ class OptunaTrainer(Trainer): best_model = CNNPredictor( **best_params ) - torch.save(best_model, "models/final_model.pt") + torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt") diff --git a/CNN-model/trainers/__init__.py b/trainers/__init__.py similarity index 100% rename from CNN-model/trainers/__init__.py rename to trainers/__init__.py diff --git a/CNN-model/trainers/train.py b/trainers/train.py similarity index 100% rename from CNN-model/trainers/train.py rename to trainers/train.py diff --git a/CNN-model/trainers/trainer.py b/trainers/trainer.py similarity index 100% rename from CNN-model/trainers/trainer.py rename to trainers/trainer.py diff --git a/CNN-model/utils/__init__.py b/utils/__init__.py similarity index 100% rename from CNN-model/utils/__init__.py rename to utils/__init__.py diff --git a/CNN-model/utils/utils.py b/utils/utils.py similarity index 100% rename from CNN-model/utils/utils.py rename to utils/utils.py