2025ML-project-neural_compr.../integer_discrete_flows/experiment_coding.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import argparse
import time

import torch
import torch.utils.data
import numpy as np

from utils.load_data import load_dataset
parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')
parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet32', 'imagenet64'],
                    metavar='DATASET',
                    help='Dataset choice.')
parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')
parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
                    help='how many batches to wait before logging training status')
parser.add_argument('--evaluate_interval_epochs', type=int, default=25,
                    help='evaluate every this many epochs')
# optimization settings
parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
                    help='number of epochs to train (default: 2000)')
parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
                    help='number of early stopping epochs')
parser.add_argument('-bs', '--batch_size', type=int, default=10, metavar='BATCH_SIZE',
                    help='input batch size (default: 10)')
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
                    help='learning rate')
parser.add_argument('--warmup', type=int, default=10,
                    help='number of warmup epochs')
parser.add_argument('--data_augmentation_level', type=int, default=2,
                    help='data augmentation level')
parser.add_argument('--no_decode', action='store_true', default=False,
                    help='disables decoding')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
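# Example invocation (flags as defined above; note that run() below loads a
# hardcoded snapshot directory):
#     python experiment_coding.py --dataset cifar10 --batch_size 10 --no_decode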
def encode_images(img, model, decode):
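    """Encode a batch of images with the flow model.

    Images larger than the model's input resolution are split into
    non-overlapping patches of the input size; each patch is encoded
    separately and the per-image bits per dimension is averaged over
    the patches.
    """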
    batchsize, img_c, img_h, img_w = img.size()
    c, h, w = model.args.input_size
    assert img_h == img_w and h == w
    if img_h != h:
        assert img_h % h == 0
        steps = img_h // h
        states = [[] for _ in range(batchsize)]
        state_sizes = [0 for _ in range(batchsize)]
        bpd = [0 for _ in range(batchsize)]
        error = 0
        for j in range(steps):
            for i in range(steps):
                r = encode_patches(
                    img[:, :, j*h:(j+1)*h, i*w:(i+1)*w], model, decode)
                for b in range(batchsize):
                    if r[0][b] is None:
                        states[b].append(None)
                    else:
                        states[b].extend(r[0][b])
                    state_sizes[b] += r[1][b]
                    bpd[b] += r[2][b] / steps**2
                error += r[3]
        return states, state_sizes, bpd, error
    else:
        return encode_patches(img, model, decode)
def encode_patches(imgs, model, decode):
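    """Encode a batch of patches that exactly match the model's input size.

    Returns the coder state per image, the state sizes in bits (including a
    one-bit escape flag per image; escaped images are stored raw at 8 bits
    per subpixel), the analytical bits per dimension, and, if `decode` is
    set, the accumulated reconstruction error.
    """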
    batchsize, img_c, img_h, img_w = imgs.size()
    c, h, w = model.args.input_size
    assert img_h == h and img_w == w
    states = model.encode(imgs)
    bpd = model.forward(imgs)[1].cpu().numpy()
    state_sizes = []
    error = 0
    for b in range(batchsize):
        if states[b] is None:
            # The escape bit marks images that are stored raw (8 bits per
            # subpixel) instead of being entropy coded.
            state_sizes += [8 * img_c * img_h * img_w + 1]
            # Error remains unchanged.
            print('Escaping, not encoding.')
        else:
            if decode:
                x_recon = model.decode([states[b]])
                error += torch.sum(
                    torch.abs(x_recon.int() - imgs[b].int())).item()
            # Append the state size plus an escape bit.
            state_sizes += [32 * len(states[b]) + 1]
    return states, state_sizes, bpd, error
def run(args, kwargs):
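    """Load a trained IDF snapshot and measure compression on the test set.

    Reports the achieved bits per dimension of the coder states, the
    analytical bpd from the model likelihood, and (unless --no_decode is
    given) the total reconstruction error after decoding.
    """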
    # Encoding and decoding must produce bit-identical network outputs,
    # so force deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================
    args.snap_dir = snap_dir = \
        'snapshots/discrete_logisticcifar10_flows_2_levels_3__2019-09-27_13_08_49/'

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    final_model = torch.load(snap_dir + 'a.model')
    # Unwrap the model if the snapshot stored a wrapper (e.g. nn.DataParallel).
    if hasattr(final_model, 'module'):
        final_model = final_model.module
    if args.cuda:
        final_model = final_model.cuda()

    sizes = []
    errors = []
    bpds = []
    start = time.time()
    t = 0
    with torch.no_grad():
        for data, _ in test_loader:
            if args.cuda:
                data = data.cuda()
            state, state_sizes, bpd, error = \
                encode_images(data, final_model, decode=not args.no_decode)
            errors += [error]
            bpds.extend(bpd)
            sizes.extend(state_sizes)
            t += len(data)
            print(
                'Examples: {}/{} bpd compression: {:.3f} error: {},'
                ' analytical bpd {:.3f}'.format(
                    t, len(test_loader.dataset),
                    np.mean(sizes) / np.prod(data.size()[1:]),
                    np.sum(errors),
                    np.mean(bpds)
                ))
    if args.no_decode:
        print('Not testing decoding.')
    else:
        print('Error: {}'.format(np.sum(errors)))
    print('Took {:.3f} seconds / example'.format((time.time() - start) / t))
    print('Final bpd: {:.3f} error: {}'.format(
        np.mean(sizes) / np.prod(data.size()[1:]),
        np.sum(errors)))
if __name__ == "__main__":
    run(args, kwargs)