feat: initial for IDF
This commit is contained in:
commit
ef4684ef39
27 changed files with 2830 additions and 0 deletions
37
integer_discrete_flows/utils/visual_evaluation.py
Normal file
37
integer_discrete_flows/utils/visual_evaluation.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import imageio
|
||||
|
||||
|
||||
def plot_reconstructions(recon_mean, loss, loss_type, epoch, args):
|
||||
if epoch == 1:
|
||||
if not os.path.exists(args.snap_dir + 'reconstruction/'):
|
||||
os.makedirs(args.snap_dir + 'reconstruction/')
|
||||
if loss_type == 'bpd':
|
||||
fname = str(epoch) + '_bpd_%5.3f' % loss
|
||||
elif loss_type == 'elbo':
|
||||
fname = str(epoch) + '_elbo_%6.4f' % loss
|
||||
plot_images(args, recon_mean.data.cpu().numpy()[:100], args.snap_dir + 'reconstruction/', fname)
|
||||
|
||||
|
||||
def plot_images(args, x_sample, dir, file_name, size_x=10, size_y=10):
|
||||
batch, channels, height, width = x_sample.shape
|
||||
|
||||
print(x_sample.shape)
|
||||
|
||||
mosaic = np.zeros((height * size_y, width * size_x, channels))
|
||||
|
||||
for j in range(size_y):
|
||||
for i in range(size_x):
|
||||
idx = j * size_x + i
|
||||
|
||||
image = x_sample[idx]
|
||||
|
||||
mosaic[j*height:(j+1)*height, i*height:(i+1)*height] = \
|
||||
image.transpose(1, 2, 0)
|
||||
|
||||
# Remove channel for BW images
|
||||
mosaic = mosaic.squeeze()
|
||||
|
||||
imageio.imwrite(dir + file_name + '.png', mosaic)
|
||||
Reference in a new issue