#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function

import argparse
import time

import torch
import torch.utils.data
import numpy as np

from utils.load_data import load_dataset


parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

parser.add_argument('-d', '--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'imagenet32', 'imagenet64'],
                    metavar='DATASET', help='Dataset choice.')
parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--manual_seed', type=int,
                    help='manual seed, if not given resorts to random seed.')
parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
                    help='how many batches to wait before logging training status')
parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
                    help='evaluate every this many epochs')

# optimization settings
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
                    help='number of epochs to train (default: 2000)')
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
                    help='number of early stopping epochs')
parser.add_argument('-bs', '--batch_size', type=int, default=10, metavar='BATCH_SIZE',
                    help='input batch size for training (default: 10)')
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
                    help='learning rate')
parser.add_argument('--warmup', type=int, default=10,
                    help='number of warmup epochs')
parser.add_argument('--data_augmentation_level', type=int, default=2,
                    help='data augmentation level')
parser.add_argument('--no_decode', action='store_true', default=False,
                    help='disables decoding')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}


def encode_images(img, model, decode):
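    """Encode a batch of images with the flow model.

    If the images are larger than the model's input resolution, they are
    tiled into non-overlapping patches of the model's input size and each
    patch is encoded separately; per-example state sizes and bits per
    dimension (bpd) are accumulated over the patch grid.
    """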
    batchsize, img_c, img_h, img_w = img.size()
    c, h, w = model.args.input_size

    assert img_h == img_w and h == w

    if img_h != h:
        # Image is larger than the model input: tile it into non-overlapping
        # h x w patches and encode each patch separately.
        assert img_h % h == 0
        steps = img_h // h

        states = [[] for _ in range(batchsize)]
        state_sizes = [0 for _ in range(batchsize)]
        bpd = [0 for _ in range(batchsize)]
        error = 0

        for j in range(steps):
            for i in range(steps):
                r = encode_patches(
                    img[:, :, j*h:(j+1)*h, i*w:(i+1)*w], model, decode)
                for b in range(batchsize):
                    if r[0][b] is None:
                        states[b].append(None)
                    else:
                        states[b].extend(r[0][b])
                    state_sizes[b] += r[1][b]
                    bpd[b] += r[2][b] / steps**2
                error += r[3]
        return states, state_sizes, bpd, error
    else:
        return encode_patches(img, model, decode)


def encode_patches(imgs, model, decode):
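    """Encode a batch of patches that exactly match the model's input size.

    Returns the per-example encoded states, their sizes in bits (falling
    back to raw 8-bit pixels plus an escape bit when encoding fails), the
    analytical bpd per example, and the summed reconstruction error when
    decoding is enabled.
    """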
    batchsize, img_c, img_h, img_w = imgs.size()
    c, h, w = model.args.input_size
    assert img_h == h and img_w == w

    states = model.encode(imgs)

    # Analytical bits per dimension from the model's log-likelihood.
    bpd = model.forward(imgs)[1].cpu().numpy()

    state_sizes = []
    error = 0

    for b in range(batchsize):
        if states[b] is None:
            # Encoding failed for this example: use the escape bit and store
            # the raw pixels at 8 bits per subpixel instead.
            state_sizes += [8 * img_c * img_h * img_w + 1]

            # Error remains unchanged.
            print('Escaping, not encoding.')
        else:
            if decode:
                x_recon = model.decode([states[b]])

                error += torch.sum(
                    torch.abs(x_recon.int() - imgs[b].int())).item()

            # State size: 32 bits per state element plus an escape bit.
            state_sizes += [32 * len(states[b]) + 1]

    return states, state_sizes, bpd, error


def run(args, kwargs):
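    """Load a trained snapshot and measure compression on the test set.

    Reports the achieved bpd (encoded state size divided by the number of
    subpixels), the model's analytical bpd, and the reconstruction error
    when decoding is enabled.
    """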
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args.snap_dir = snap_dir = \
        'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'

    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    final_model = torch.load(snap_dir + 'a.model')
    if hasattr(final_model, 'module'):
        # Unwrap a DataParallel-wrapped model.
        final_model = final_model.module
    if args.cuda:
        # Only move the model to the GPU when CUDA is requested and available.
        final_model = final_model.cuda()

    sizes = []
    errors = []
    bpds = []

    start = time.time()

    t = 0
    with torch.no_grad():
        for data, _ in test_loader:
            if args.cuda:
                data = data.cuda()

            state, state_sizes, bpd, error = \
                encode_images(data, final_model, decode=not args.no_decode)

            errors += [error]
            bpds.extend(bpd)
            sizes.extend(state_sizes)

            t += len(data)

            print(
                'Examples: {}/{} bpd compression: {:.3f} error: {},'
                ' analytical bpd {:.3f}'.format(
                    t, len(test_loader.dataset),
                    np.mean(sizes) / np.prod(data.size()[1:]),
                    np.sum(errors),
                    np.mean(bpds)
                ))

    if args.no_decode:
        print('Not testing decoding.')
    else:
        print('Error: {}'.format(np.sum(errors)))

    print('Took {:.3f} seconds / example'.format((time.time() - start) / t))
    print('Final bpd: {:.3f} error: {}'.format(
        np.mean(sizes) / np.prod(data.size()[1:]),
        np.sum(errors)))


if __name__ == "__main__":
    run(args, kwargs)
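
# A hypothetical invocation (the filename encode_images.py is an assumption,
# as is a trained snapshot at the hardcoded snap_dir above):
#
#   python encode_images.py --dataset cifar10 --batch_size 10 --no_decode
#
# Dropping --no_decode also decodes every encoded state and reports the total
# absolute reconstruction error, which should be 0 for a lossless code.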