from __future__ import print_function

import numpy as np
import torch

from utils.distributions import (
    log_discretized_logistic, log_mixture_discretized_logistic, log_normal,
    log_discretized_normal, log_logistic, log_mixture_normal)
from models.backround import _round_straightthrough


def compute_log_ps(pxs, xs, args):
    """Log-likelihoods of the intermediate representations under their priors."""
    inverse_bin_width = 2.**args.n_bits

    log_pxs = []
    for px, x in zip(pxs, xs):
        if args.variable_type == 'discrete':
            if args.distribution_type == 'logistic':
                log_px = log_discretized_logistic(
                    x, *px, inverse_bin_width=inverse_bin_width)
            elif args.distribution_type == 'normal':
                log_px = log_discretized_normal(
                    x, *px, inverse_bin_width=inverse_bin_width)

        elif args.variable_type == 'continuous':
            if args.distribution_type == 'logistic':
                log_px = log_logistic(x, *px)
            elif args.distribution_type == 'normal':
                log_px = log_normal(x, *px)
            elif args.distribution_type == 'steplogistic':
                # Round to the discretization grid with straight-through
                # gradients, then evaluate under the discretized logistic.
                x = _round_straightthrough(
                    x * inverse_bin_width) / inverse_bin_width
                log_px = log_discretized_logistic(
                    x, *px, inverse_bin_width=inverse_bin_width)

        else:
            raise ValueError(
                'Unknown variable type: %s' % args.variable_type)

        # Sum over channel and spatial dimensions; keep the batch dimension.
        log_pxs.append(torch.sum(log_px, dim=[1, 2, 3]))

    return log_pxs


def compute_log_pz(pz, z, args):
    """Log-likelihood of the final representation z under the prior."""
    inverse_bin_width = 2.**args.n_bits

    if args.variable_type == 'discrete':
        if args.distribution_type == 'logistic':
            if args.n_mixtures == 1:
                log_pz = log_discretized_logistic(
                    z, pz[0], pz[1], inverse_bin_width=inverse_bin_width)
            else:
                log_pz = log_mixture_discretized_logistic(
                    z, pz[0], pz[1], pz[2],
                    inverse_bin_width=inverse_bin_width)
        elif args.distribution_type == 'normal':
            log_pz = log_discretized_normal(
                z, *pz, inverse_bin_width=inverse_bin_width)

    elif args.variable_type == 'continuous':
        if args.distribution_type == 'logistic':
            log_pz = log_logistic(z, *pz)
        elif args.distribution_type == 'normal':
            if args.n_mixtures == 1:
                log_pz = log_normal(z, *pz)
            else:
                log_pz = log_mixture_normal(z, *pz)
        elif args.distribution_type == 'steplogistic':
            # Use the bin width derived from args.n_bits rather than a
            # hard-coded 256, which is only correct for 8-bit data.
            z = _round_straightthrough(
                z * inverse_bin_width) / inverse_bin_width
            log_pz = log_discretized_logistic(
                z, *pz, inverse_bin_width=inverse_bin_width)

    else:
        raise ValueError('Unknown variable type: %s' % args.variable_type)

    # Sum over channel and spatial dimensions; keep the batch dimension.
    log_pz = torch.sum(log_pz, dim=[1, 2, 3])

    return log_pz


def compute_loss_function(pz, z, pys, ys, ldj, args):
    """
    Computes the negative log-likelihood loss, averaged over the batch.

    :param pz: parameters of the prior over the final representation
    :param z: final representation,
        shape (batch_size, num_channels, height, width)
    :param pys: list of parameters of the priors over the intermediate
        representations
    :param ys: list of intermediate representations
    :param ldj: log-determinant of the Jacobian of the flow,
        shape (batch_size,)
    :param args: global parameter settings
    :return: loss, bits per dimension, bits per dimension per prior
    """
    loss_array, bpd_array, bpd_per_prior_array = \
        compute_loss_array(pz, z, pys, ys, ldj, args)

    # Average the per-example quantities over the batch dimension.
    loss = torch.mean(loss_array)
    bpd = torch.mean(bpd_array).item()
    bpd_per_prior = [torch.mean(x) for x in bpd_per_prior_array]

    return loss, bpd, bpd_per_prior


def convert_bpd(log_p, input_size):
    """Converts a log-likelihood in nats to bits per dimension."""
    return -log_p / (np.prod(input_size) * np.log(2.))


def compute_loss_array(pz, z, pys, ys, ldj, args):
    """
    Computes the per-example negative log-likelihood, without averaging
    over the batch dimension.

    Parameters are as in compute_loss_function; returns per-example
    tensors of shape (batch_size,) instead of batch averages.
    """
    bpd_per_prior = []

    # Likelihood of the final representation.
    log_pz = compute_log_pz(pz, z, args)
    bpd_per_prior.append(convert_bpd(log_pz.detach(), args.input_size))
    log_p = log_pz

    # Add likelihoods of intermediate representations.
    if ys:
        log_pys = compute_log_ps(pys, ys, args)
        for log_py in log_pys:
            log_p = log_p + log_py
            bpd_per_prior.append(
                convert_bpd(log_py.detach(), args.input_size))

    # Account for the volume change of the flow.
    log_p = log_p + ldj

    loss = -log_p
    bpd = convert_bpd(log_p.detach(), args.input_size)

    return loss, bpd, bpd_per_prior


def calculate_loss(pz, z, pys, ys, ldj, loss_aux, args):
    # Note: loss_aux is accepted for interface compatibility but is not
    # used here; compute_loss_function takes no auxiliary loss term.
    return compute_loss_function(pz, z, pys, ys, ldj, args)
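

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the training pipeline). It assumes the
# conventions implied by the call sites above: the prior is a tuple of
# parameter tensors matching the signatures in utils.distributions (here a
# two-parameter tuple for log_normal, e.g. mean and log-variance), and `args`
# only needs the attributes read by the functions in this module. All names,
# shapes, and values below are illustrative, not canonical.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    batch_size, channels, height, width = 8, 3, 32, 32

    args = Namespace(
        variable_type='continuous',   # evaluated with log_normal below
        distribution_type='normal',
        n_mixtures=1,
        n_bits=8,                     # only used by the discretized branches
        input_size=(channels, height, width))

    # Final representation and a two-parameter prior; the exact
    # parameterization is defined by utils.distributions.log_normal.
    z = torch.randn(batch_size, channels, height, width)
    pz = (torch.zeros_like(z), torch.zeros_like(z))

    # No intermediate representations and a zero log-det Jacobian here.
    ldj = torch.zeros(batch_size)

    loss, bpd, bpd_per_prior = calculate_loss(
        pz, z, pys=[], ys=[], ldj=ldj, loss_aux=None, args=args)

    print('loss: %.3f, bpd: %.3f' % (loss.item(), bpd))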