feat: initial for IDF
commit ef4684ef39
27 changed files with 2830 additions and 0 deletions
integer_discrete_flows/experiment_coding.py (new file, 188 lines)
@@ -0,0 +1,188 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import argparse
import time

import torch
import torch.utils.data
import numpy as np

from utils.load_data import load_dataset

parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'],
                    metavar='DATASET',
                    help='Dataset choice.')

parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')

parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')

parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
                    help='how many batches to wait before logging training status')

parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
                    help='Evaluate per how many epochs')


# optimization settings
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
                    help='number of epochs to train (default: 2000)')
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
                    help='number of early stopping epochs')

parser.add_argument('-bs', '--batch_size', type=int, default=10, metavar='BATCH_SIZE',
                    help='input batch size for training (default: 10)')
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
                    help='learning rate')
parser.add_argument('--warmup', type=int, default=10,
                    help='number of warmup epochs')

parser.add_argument('--data_augmentation_level', type=int, default=2,
                    help='data augmentation level')

parser.add_argument('--no_decode', action='store_true', default=False,
                    help='disables decoding')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

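# Note: --manual_seed is parsed above but never applied in this script. A
# minimal way to honour it (an illustrative addition, not part of the
# original commit) would be:
if args.manual_seed is not None:
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)
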
def encode_images(img, model, decode):
    batchsize, img_c, img_h, img_w = img.size()
    c, h, w = model.args.input_size

    assert img_h == img_w and h == w

    if img_h != h:
        assert img_h % h == 0
        steps = img_h // h

        states = [[] for i in range(batchsize)]
        state_sizes = [0 for i in range(batchsize)]
        bpd = [0 for i in range(batchsize)]
        error = 0

        for j in range(steps):
            for i in range(steps):
                r = encode_patches(
                    img[:, :, j*h:(j+1)*h, i*w:(i+1)*w], model, decode)
                for b in range(batchsize):
                    if r[0][b] is None:
                        states[b].append(None)
                    else:
                        states[b].extend(r[0][b])
                    state_sizes[b] += r[1][b]
                    bpd[b] += r[2][b] / steps**2
                error += r[3]

        return states, state_sizes, bpd, error

    else:
        return encode_patches(img, model, decode)

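# Illustration (an assumption about intended use, not asserted by the commit):
# encode_images tiles inputs larger than the model resolution into
# non-overlapping h x w patches. E.g. with a model trained on 32x32 inputs, a
# 64x64 imagenet64 image gives steps = 2, so 4 patches are coded
# independently and each contributes r[2][b] / steps**2 to the per-image
# analytical bpd.
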
def encode_patches(imgs, model, decode):
    batchsize, img_c, img_h, img_w = imgs.size()
    c, h, w = model.args.input_size
    assert img_h == h and img_w == w

    states = model.encode(imgs)

    bpd = model.forward(imgs)[1].cpu().numpy()

    state_sizes = []
    error = 0

    for b in range(batchsize):
        if states[b] is None:
            # Using escape bit ;)
            state_sizes += [8 * img_c * img_h * img_w + 1]

            # Error remains unchanged.
            print('Escaping, not encoding.')

        else:
            if decode:
                x_recon = model.decode([states[b]])

                error += torch.sum(
                    torch.abs(x_recon.int() - imgs[b].int())).item()

            # Append state plus an escape bit
            state_sizes += [32 * len(states[b]) + 1]

    return states, state_sizes, bpd, error

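# State-size accounting above (an interpretation for the reader, not stated
# in the original commit): when the model declines to encode an image
# (states[b] is None), the image is stored raw at 8 bits per subpixel plus a
# one-bit escape flag; otherwise each element of the coder state costs 32
# bits plus the same flag. The "bpd compression" printed in run() is then
# np.mean(sizes) / np.prod(data.size()[1:]), i.e. bits per dimension.
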
def run(args, kwargs):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args.snap_dir = snap_dir = \
        'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'

    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    final_model = torch.load(snap_dir + 'a.model',
                             map_location=None if args.cuda else 'cpu')
    if hasattr(final_model, 'module'):
        final_model = final_model.module
    if args.cuda:
        final_model = final_model.cuda()

    sizes = []
    errors = []
    bpds = []

    start = time.time()

    t = 0
    with torch.no_grad():
        for data, _ in test_loader:
            if args.cuda:
                data = data.cuda()

            state, state_sizes, bpd, error = \
                encode_images(data, final_model, decode=not args.no_decode)

            errors += [error]
            bpds.extend(bpd)
            sizes.extend(state_sizes)

            t += len(data)

            print(
                'Examples: {}/{} bpd compression: {:.3f} error: {},'
                ' analytical bpd {:.3f}'.format(
                    t, len(test_loader.dataset),
                    np.mean(sizes) / np.prod(data.size()[1:]),
                    np.sum(errors),
                    np.mean(bpds)
                ))

    if args.no_decode:
        print('Not testing decoding.')
    else:
        print('Error: {}'.format(np.sum(errors)))

    print('Took {:.3f} seconds / example'.format((time.time() - start) / t))
    print('Final bpd: {:.3f} error: {}'.format(
        np.mean(sizes) / np.prod(data.size()[1:]),
        np.sum(errors)))

if __name__ == "__main__":
    run(args, kwargs)
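
# Example invocation (hypothetical; assumes a trained snapshot exists at the
# hard-coded snap_dir inside run()):
#
#     python experiment_coding.py -d cifar10 -bs 10
#
# Pass --no_decode to report the compression rate only and skip the
# decode-and-compare round trip.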