feat: initial commit for IDF (Integer Discrete Flows)
commit ef4684ef39
27 changed files with 2830 additions and 0 deletions
integer_discrete_flows/utils/__init__.py  (Normal file, 0 lines added)
integer_discrete_flows/utils/distributions.py  (Normal file, 209 lines added)
@@ -0,0 +1,209 @@
from __future__ import print_function

import torch
import torch.utils.data
import torch.nn.functional as F

import numpy as np
import math

MIN_EPSILON = 1e-5
MAX_EPSILON = 1. - 1e-5


PI = math.pi


def log_min_exp(a, b, epsilon=1e-8):
    """
    Computes the log of exp(a) - exp(b) in a (more) numerically stable fashion.
    Using:
        log(exp(a) - exp(b))
          = c + log(exp(a - c) - exp(b - c))
          = a + log(1 - exp(b - a)),  taking c = a.
    Note that we assume b < a always.
    """
    y = a + torch.log(1 - torch.exp(b - a) + epsilon)

    return y


def log_normal(x, mean, logvar):
    logp = -0.5 * logvar
    logp += -0.5 * np.log(2 * PI)
    logp += -0.5 * (x - mean) * (x - mean) / torch.exp(logvar)
    return logp


def log_mixture_normal(x, mean, logvar, pi):
    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

    logp_mixtures = log_normal(x, mean, logvar)

    logp = torch.log(torch.sum(pi * torch.exp(logp_mixtures), dim=-1) + 1e-8)

    return logp


def sample_normal(mean, logvar):
    y = torch.randn_like(mean)

    x = torch.exp(0.5 * logvar) * y + mean

    return x


def sample_mixture_normal(mean, logvar, pi):
    b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
    pi = pi.view(b * c * h * w, n_mixtures)
    sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)

    # Select mixture params
    mean = mean.view(b * c * h * w, n_mixtures)
    mean = mean[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)
    logvar = logvar.view(b * c * h * w, n_mixtures)
    logvar = logvar[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)

    y = sample_normal(mean, logvar)

    return y


def log_logistic(x, mean, logscale):
    """
    pdf = sigma([x - mean] / scale) * [1 - sigma(...)] * 1/scale
    """
    scale = torch.exp(logscale)

    u = (x - mean) / scale

    logp = F.logsigmoid(u) + F.logsigmoid(-u) - logscale

    return logp


def sample_logistic(mean, logscale):
    y = torch.rand_like(mean)

    x = torch.exp(logscale) * torch.log(y / (1 - y)) + mean

    return x


def log_discretized_logistic(x, mean, logscale, inverse_bin_width):
    scale = torch.exp(logscale)

    logp = log_min_exp(
        F.logsigmoid((x + 0.5 / inverse_bin_width - mean) / scale),
        F.logsigmoid((x - 0.5 / inverse_bin_width - mean) / scale))

    return logp


def discretized_logistic_cdf(x, mean, logscale, inverse_bin_width):
    scale = torch.exp(logscale)

    cdf = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)

    return cdf


def sample_discretized_logistic(mean, logscale, inverse_bin_width):
    x = sample_logistic(mean, logscale)

    x = torch.round(x * inverse_bin_width) / inverse_bin_width
    return x


def normal_cdf(value, loc, std):
    return 0.5 * (1 + torch.erf((value - loc) * std.reciprocal() / math.sqrt(2)))


def log_discretized_normal(x, mean, logvar, inverse_bin_width):
    std = torch.exp(0.5 * logvar)
    log_p = torch.log(
        normal_cdf(x + 0.5 / inverse_bin_width, mean, std)
        - normal_cdf(x - 0.5 / inverse_bin_width, mean, std) + 1e-7)

    return log_p


def log_mixture_discretized_normal(x, mean, logvar, pi, inverse_bin_width):
    std = torch.exp(0.5 * logvar)

    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

    p = normal_cdf(x + 0.5 / inverse_bin_width, mean, std) \
        - normal_cdf(x - 0.5 / inverse_bin_width, mean, std)

    p = torch.sum(p * pi, dim=-1)

    logp = torch.log(p + 1e-8)

    return logp


def sample_discretized_normal(mean, logvar, inverse_bin_width):
    y = torch.randn_like(mean)

    x = torch.exp(0.5 * logvar) * y + mean

    x = torch.round(x * inverse_bin_width) / inverse_bin_width

    return x


def log_mixture_discretized_logistic(x, mean, logscale, pi, inverse_bin_width):
    scale = torch.exp(logscale)

    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

    p = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale) \
        - torch.sigmoid((x - 0.5 / inverse_bin_width - mean) / scale)

    p = torch.sum(p * pi, dim=-1)

    logp = torch.log(p + 1e-8)

    return logp


def mixture_discretized_logistic_cdf(x, mean, logscale, pi, inverse_bin_width):
    scale = torch.exp(logscale)

    x = x[..., None]

    cdfs = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)

    cdf = torch.sum(cdfs * pi, dim=-1)

    return cdf


def sample_mixture_discretized_logistic(mean, logs, pi, inverse_bin_width):
    # Sample mixture components
    b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
    pi = pi.view(b * c * h * w, n_mixtures)
    sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)

    # Select mixture params
    mean = mean.view(b * c * h * w, n_mixtures)
    mean = mean[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)
    logs = logs.view(b * c * h * w, n_mixtures)
    logs = logs[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)

    y = torch.rand_like(mean)
    x = torch.exp(logs) * torch.log(y / (1 - y)) + mean

    x = torch.round(x * inverse_bin_width) / inverse_bin_width

    return x


def log_multinomial(logits, targets):
    return -F.cross_entropy(logits, targets, reduction='none')


def sample_multinomial(logits):
    b, n_categories, c, h, w = logits.size()
    logits = logits.permute(0, 2, 3, 4, 1)
    p = F.softmax(logits, dim=-1)
    p = p.view(b * c * h * w, n_categories)
    x = torch.multinomial(p, num_samples=1).view(b, c, h, w)
    return x
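Usage sketch for the discretized-logistic helpers above (not part of the commit; the shapes, the [0, 1) rescaling, inverse_bin_width=256 for 8-bit data, and the import path are illustrative assumptions):

import math
import torch
from integer_discrete_flows.utils.distributions import (
    log_discretized_logistic, sample_discretized_logistic)

# Fake batch of 8-bit images rescaled to [0, 1); bins then have width 1/256.
x = torch.randint(0, 256, (4, 3, 8, 8)).float() / 256.
mean = torch.zeros_like(x)
logscale = torch.zeros_like(x)

logp = log_discretized_logistic(x, mean, logscale, inverse_bin_width=256)     # per-pixel log-mass
bpd = -logp.sum(dim=(1, 2, 3)) / (3 * 8 * 8) / math.log(2.)                   # bits per dimension
sample = sample_discretized_logistic(mean, logscale, inverse_bin_width=256)   # values on the 1/256 grid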
integer_discrete_flows/utils/load_data.py  (Normal file, 264 lines added)
@@ -0,0 +1,264 @@
from __future__ import print_function

import numbers
import math

import torch
import torch.utils.data as data_utils
import pickle
from scipy.io import loadmat

import numpy as np

import os
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from torchvision.transforms import functional as vf
from torch.utils.data import ConcatDataset

from PIL import Image

import os.path
from os.path import join
import sys
import tarfile


class ToTensorNoNorm():
    def __call__(self, X_i):
        return torch.from_numpy(np.array(X_i, copy=False)).permute(2, 0, 1)


class PadToMultiple(object):
    def __init__(self, multiple, fill=0, padding_mode='constant'):
        assert isinstance(multiple, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']

        self.multiple = multiple
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.
        Returns:
            PIL Image: Padded image.
        """
        w, h = img.size
        m = self.multiple
        nw = (w // m + int((w % m) != 0)) * m
        nh = (h // m + int((h % m) != 0)) * m
        padw = nw - w
        padh = nh - h

        out = vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode)
        return out

    def __repr__(self):
        return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\
            format(self.multiple, self.fill, self.padding_mode)


class CustomTensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        X, y = self.tensors
        X_i, y_i = X[index], y[index]

        if self.transform:
            X_i = self.transform(X_i)
            X_i = torch.from_numpy(np.array(X_i, copy=False))
            X_i = X_i.permute(2, 0, 1)

        return X_i, y_i

    def __len__(self):
        return self.tensors[0].size(0)


def load_cifar10(args, **kwargs):
    # set args
    args.input_size = [3, 32, 32]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    from keras.datasets import cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    x_train = x_train.transpose(0, 3, 1, 2)
    x_test = x_test.transpose(0, 3, 1, 2)

    if args.data_augmentation_level == 2:
        data_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.CenterCrop(32)
        ])
    elif args.data_augmentation_level == 1:
        data_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
        ])
    else:
        data_transform = transforms.Compose([
            transforms.ToPILImage(),
        ])

    x_val = x_train[-10000:]
    y_val = y_train[-10000:]

    x_train = x_train[:-10000]
    y_train = y_train[:-10000]

    train = CustomTensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train), transform=data_transform)
    train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)

    validation = data_utils.TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val))
    val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs)

    test = data_utils.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test))
    test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader, args


def extract_tar(tarpath):
    assert tarpath.endswith('.tar')

    startdir = tarpath[:-4] + '/'

    if os.path.exists(startdir):
        return startdir

    print('Extracting', tarpath)

    with tarfile.open(name=tarpath) as tar:
        t = 0
        done = False
        while not done:
            path = join(startdir, 'images{}'.format(t))
            os.makedirs(path, exist_ok=True)

            print(path)

            for i in range(50000):
                member = tar.next()

                if member is None:
                    done = True
                    break

                # Skip directories
                while member.isdir():
                    member = tar.next()
                    if member is None:
                        done = True
                        break

                member.name = member.name.split('/')[-1]

                tar.extract(member, path=path)

            t += 1

    return startdir


def load_imagenet(resolution, args, **kwargs):
    assert resolution == 32 or resolution == 64

    args.input_size = [3, resolution, resolution]

    trainpath = '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution)
    valpath = '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution)

    trainpath = extract_tar(trainpath)
    valpath = extract_tar(valpath)

    data_transform = transforms.Compose([
        ToTensorNoNorm()
    ])

    print('Starting loading ImageNet')

    imagenet_data = torchvision.datasets.ImageFolder(
        trainpath,
        transform=data_transform)

    print('Number of data images', len(imagenet_data))

    val_idcs = np.random.choice(len(imagenet_data), size=20000, replace=False)
    train_idcs = np.setdiff1d(np.arange(len(imagenet_data)), val_idcs)

    train_dataset = torch.utils.data.dataset.Subset(
        imagenet_data, train_idcs)
    val_dataset = torch.utils.data.dataset.Subset(
        imagenet_data, val_idcs)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    test_dataset = torchvision.datasets.ImageFolder(
        valpath,
        transform=data_transform)

    print('Number of val images:', len(test_dataset))

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    return train_loader, val_loader, test_loader, args


def load_dataset(args, **kwargs):

    if args.dataset == 'cifar10':
        train_loader, val_loader, test_loader, args = load_cifar10(args, **kwargs)
    elif args.dataset == 'imagenet32':
        train_loader, val_loader, test_loader, args = load_imagenet(32, args, **kwargs)
    elif args.dataset == 'imagenet64':
        train_loader, val_loader, test_loader, args = load_imagenet(64, args, **kwargs)
    else:
        raise Exception('Wrong name of the dataset!')

    return train_loader, val_loader, test_loader, args


if __name__ == '__main__':
    class Args():
        def __init__(self):
            self.batch_size = 128
    train_loader, val_loader, test_loader, args = load_imagenet(32, Args())
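Sketch of how load_dataset above would typically be driven (not part of the commit; the flag names mirror the attributes the loaders read, while the import path and DataLoader kwargs are assumptions):

import argparse
from integer_discrete_flows.utils.load_data import load_dataset

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10',
                    choices=['cifar10', 'imagenet32', 'imagenet64'])
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--data_augmentation_level', type=int, default=1)
args = parser.parse_args([])

# The CIFAR-10 path additionally needs keras installed, as in load_cifar10 above.
train_loader, val_loader, test_loader, args = load_dataset(args, num_workers=2, pin_memory=True)
for x, y in train_loader:
    print(x.shape, x.dtype)  # e.g. torch.Size([128, 3, 32, 32]) torch.uint8
    break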
integer_discrete_flows/utils/log_likelihood.py  (Normal file, 56 lines added)
@@ -0,0 +1,56 @@
from __future__ import print_function
import numpy as np
# logsumexp lives in scipy.special (it was removed from scipy.misc in SciPy >= 1.0)
from scipy.special import logsumexp
from optimization.loss import calculate_loss_array


def calculate_likelihood(X, model, args, S=5000, MB=500):

    # set auxiliary variables for number of training and test sets
    N_test = X.size(0)

    X = X.view(-1, *args.input_size)

    likelihood_test = []

    if S <= MB:
        R = 1
    else:
        R = S // MB
        S = MB

    for j in range(N_test):
        if j % 100 == 0:
            print('Progress: {:.2f}%'.format(j / (1. * N_test) * 100))

        x_single = X[j].unsqueeze(0)

        a = []
        for r in range(0, R):
            # Repeat the single test point S times and run a stochastic forward pass
            x = x_single.expand(S, *x_single.size()[1:]).contiguous()

            x_mean, z_mu, z_var, ldj, z0, zk = model(x)

            a_tmp = calculate_loss_array(x_mean, x, z_mu, z_var, z0, zk, ldj, args)

            a.append(-a_tmp.cpu().data.numpy())

        # Average the S*R per-sample likelihoods in log-space
        a = np.asarray(a)
        a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
        likelihood_x = logsumexp(a)
        likelihood_test.append(likelihood_x - np.log(len(a)))

    likelihood_test = np.array(likelihood_test)

    nll = -np.mean(likelihood_test)

    if args.input_type == 'multinomial':
        bpd = nll / (np.prod(args.input_size) * np.log(2.))
    elif args.input_type == 'binary':
        bpd = 0.
    else:
        raise ValueError('invalid input type!')

    return nll, bpd
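For reference, the quantity accumulated above is the stabilized Monte-Carlo estimate log p(x) ~ logsumexp_s(a_s) - log S, where a_s are the per-sample log-weights collected in a. A small self-contained check of that identity (not part of the commit; the values are illustrative):

import numpy as np
from scipy.special import logsumexp

a = np.array([-1200.3, -1198.7, -1199.5])   # illustrative per-sample log-weights
log_px = logsumexp(a) - np.log(len(a))      # log of the mean of exp(a), computed stably
direct = np.log(np.mean(np.exp(a - a.max()))) + a.max()
assert np.isclose(log_px, direct)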
integer_discrete_flows/utils/plotting.py  (Normal file, 104 lines added)
@@ -0,0 +1,104 @@
from __future__ import division
from __future__ import print_function

import numpy as np
import matplotlib
# noninteractive backend
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def plot_training_curve(train_loss, validation_loss, fname='training_curve.pdf', labels=None):
    """
    Plots train_loss and validation loss as a function of optimization iteration
    :param train_loss: np.array of train_loss (1D or 2D)
    :param validation_loss: np.array of validation loss (1D or 2D)
    :param fname: output file name
    :param labels: if train_loss and validation loss are 2D, then labels indicate which variable is varied
        across training curves.
    :return: None
    """

    plt.close()

    matplotlib.rcParams.update({'font.size': 14})
    matplotlib.rcParams['mathtext.fontset'] = 'stix'
    matplotlib.rcParams['font.family'] = 'STIXGeneral'

    if len(train_loss.shape) == 1:
        # Single training curve
        fig, ax = plt.subplots(nrows=1, ncols=1)
        figsize = (6, 4)

        if train_loss.shape[0] == validation_loss.shape[0]:
            # validation score evaluated every iteration
            x = np.arange(train_loss.shape[0])
            ax.plot(x, train_loss, '-', lw=2., color='black', label='train')
            ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')

        elif train_loss.shape[0] % validation_loss.shape[0] == 0:
            # validation score evaluated every epoch
            x = np.arange(train_loss.shape[0])
            ax.plot(x, train_loss, '-', lw=2., color='black', label='train')

            x = np.arange(validation_loss.shape[0])
            x = (x + 1) * train_loss.shape[0] / validation_loss.shape[0]
            ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val')
        else:
            raise ValueError('Length of train_loss and validation_loss must be equal or divisible')

        miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
        maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
        ax.set_ylim([miny, maxy])

    elif len(train_loss.shape) == 2:
        # Multiple training curves

        cmap = plt.cm.brg

        cNorm = matplotlib.colors.Normalize(vmin=0, vmax=train_loss.shape[0])
        scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)

        fig, ax = plt.subplots(nrows=1, ncols=1)
        figsize = (6, 4)

        if labels is None:
            labels = ['%d' % i for i in range(train_loss.shape[0])]

        if train_loss.shape[1] == validation_loss.shape[1]:
            for i in range(train_loss.shape[0]):
                color_val = scalarMap.to_rgba(i)

                # validation score evaluated every iteration
                # (x must index iterations, i.e. the second axis)
                x = np.arange(train_loss.shape[1])
                ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i])
                ax.plot(x, validation_loss[i], '--', lw=2., color=color_val)

        elif train_loss.shape[1] % validation_loss.shape[1] == 0:
            for i in range(train_loss.shape[0]):
                color_val = scalarMap.to_rgba(i)

                # validation score evaluated every epoch
                x = np.arange(train_loss.shape[1])
                ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i])

                x = np.arange(validation_loss.shape[1])
                x = (x + 1) * train_loss.shape[1] / validation_loss.shape[1]
                ax.plot(x, validation_loss[i], '-', lw=2., color=color_val)

        miny = np.minimum(validation_loss.min(), train_loss.min()) - 20.
        maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30.
        ax.set_ylim([miny, maxy])

    else:
        raise ValueError('train_loss and validation_loss must be 1D or 2D arrays')

    ax.set_xlabel('iteration')
    ax.set_ylabel('loss')
    plt.title('Training and validation loss')

    fig.set_size_inches(figsize)
    fig.subplots_adjust(hspace=0.1)
    plt.savefig(fname, bbox_inches='tight')

    plt.close()
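A quick smoke test for plot_training_curve above (not part of the commit; the import path and the array shapes, which exercise the 1D once-per-epoch branch, are assumptions):

import numpy as np
from integer_discrete_flows.utils.plotting import plot_training_curve

train_loss = 100. - np.linspace(0., 50., 1000) + np.random.randn(1000)  # one value per iteration
val_loss = 100. - np.linspace(0., 50., 10)                              # one value per epoch
plot_training_curve(train_loss, val_loss, fname='training_curve.pdf')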
integer_discrete_flows/utils/visual_evaluation.py  (Normal file, 37 lines added)
@@ -0,0 +1,37 @@
from __future__ import print_function
import os
import numpy as np
import imageio


def plot_reconstructions(recon_mean, loss, loss_type, epoch, args):
    if epoch == 1:
        if not os.path.exists(args.snap_dir + 'reconstruction/'):
            os.makedirs(args.snap_dir + 'reconstruction/')

    if loss_type == 'bpd':
        fname = str(epoch) + '_bpd_%5.3f' % loss
    elif loss_type == 'elbo':
        fname = str(epoch) + '_elbo_%6.4f' % loss

    plot_images(args, recon_mean.data.cpu().numpy()[:100], args.snap_dir + 'reconstruction/', fname)


def plot_images(args, x_sample, dir, file_name, size_x=10, size_y=10):
    batch, channels, height, width = x_sample.shape

    print(x_sample.shape)

    mosaic = np.zeros((height * size_y, width * size_x, channels))

    for j in range(size_y):
        for i in range(size_x):
            idx = j * size_x + i

            image = x_sample[idx]

            # Place image idx into row j, column i of the mosaic
            # (columns are indexed by width, not height)
            mosaic[j * height:(j + 1) * height, i * width:(i + 1) * width] = \
                image.transpose(1, 2, 0)

    # Remove channel for BW images
    mosaic = mosaic.squeeze()

    imageio.imwrite(dir + file_name + '.png', mosaic)