from __future__ import print_function

import math

import numpy as np
import torch
import torch.nn.functional as F

MIN_EPSILON = 1e-5
MAX_EPSILON = 1. - 1e-5

PI = math.pi


def log_min_exp(a, b, epsilon=1e-8):
    """
    Computes log(exp(a) - exp(b)) in a (more) numerically stable fashion, using

        log(exp(a) - exp(b)) = c + log(exp(a - c) - exp(b - c))
                             = a + log(1 - exp(b - a))   with c = a.

    Note that we assume b < a always.
    """
    y = a + torch.log(1 - torch.exp(b - a) + epsilon)
    return y


def log_normal(x, mean, logvar):
    # Elementwise log-density of a Gaussian with diagonal covariance.
    logp = -0.5 * logvar
    logp += -0.5 * np.log(2 * PI)
    logp += -0.5 * (x - mean) * (x - mean) / torch.exp(logvar)
    return logp


def log_mixture_normal(x, mean, logvar, pi):
    # x: [B, C, H, W]; mean, logvar, pi: [B, C, H, W, n_mixtures].
    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

    logp_mixtures = log_normal(x, mean, logvar)

    logp = torch.log(torch.sum(pi * torch.exp(logp_mixtures), dim=-1) + 1e-8)

    return logp


def sample_normal(mean, logvar):
    # Reparameterized sample: x = mean + std * eps with eps ~ N(0, I).
    y = torch.randn_like(mean)
    x = torch.exp(0.5 * logvar) * y + mean
    return x


def sample_mixture_normal(mean, logvar, pi):
    # Sample a mixture component per position, then sample from the
    # selected Gaussians.
    b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
    pi = pi.view(b * c * h * w, n_mixtures)
    sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)

    # Select mixture params
    mean = mean.view(b * c * h * w, n_mixtures)
    mean = mean[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)

    logvar = logvar.view(b * c * h * w, n_mixtures)
    logvar = logvar[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)

    y = sample_normal(mean, logvar)
    return y


def log_logistic(x, mean, logscale):
    """
    pdf(x) = sigmoid(u) * (1 - sigmoid(u)) / scale,  where u = (x - mean) / scale.
    """
    scale = torch.exp(logscale)
    u = (x - mean) / scale
    logp = F.logsigmoid(u) + F.logsigmoid(-u) - logscale
    return logp


def sample_logistic(mean, logscale):
    # Inverse-CDF sampling: x = mean + scale * logit(u) with u ~ U(0, 1).
    # Clamp u away from 0 and 1 so the logit stays finite.
    y = torch.rand_like(mean)
    y = torch.clamp(y, min=MIN_EPSILON, max=MAX_EPSILON)
    x = torch.exp(logscale) * torch.log(y / (1 - y)) + mean
    return x


def log_discretized_logistic(x, mean, logscale, inverse_bin_width):
    # log P(x) = log[CDF(x + w/2) - CDF(x - w/2)] for a logistic CDF and
    # bin width w = 1 / inverse_bin_width, computed in log-space.
    scale = torch.exp(logscale)

    logp = log_min_exp(
        F.logsigmoid((x + 0.5 / inverse_bin_width - mean) / scale),
        F.logsigmoid((x - 0.5 / inverse_bin_width - mean) / scale))

    return logp


def discretized_logistic_cdf(x, mean, logscale, inverse_bin_width):
    # CDF evaluated at the right edge of the bin containing x.
    scale = torch.exp(logscale)

    cdf = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)

    return cdf


def sample_discretized_logistic(mean, logscale, inverse_bin_width):
    # Sample a continuous logistic, then round to the discretization grid.
    x = sample_logistic(mean, logscale)
    x = torch.round(x * inverse_bin_width) / inverse_bin_width
    return x


def normal_cdf(value, loc, std):
    # Gaussian CDF expressed via the error function.
    return 0.5 * (1 + torch.erf((value - loc) * std.reciprocal() / math.sqrt(2)))


def log_discretized_normal(x, mean, logvar, inverse_bin_width):
    std = torch.exp(0.5 * logvar)

    log_p = torch.log(
        normal_cdf(x + 0.5 / inverse_bin_width, mean, std)
        - normal_cdf(x - 0.5 / inverse_bin_width, mean, std) + 1e-7)

    return log_p


def log_mixture_discretized_normal(x, mean, logvar, pi, inverse_bin_width):
    # x: [B, C, H, W]; mean, logvar, pi: [B, C, H, W, n_mixtures].
    std = torch.exp(0.5 * logvar)

    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

    p = normal_cdf(x + 0.5 / inverse_bin_width, mean, std) \
        - normal_cdf(x - 0.5 / inverse_bin_width, mean, std)

    p = torch.sum(p * pi, dim=-1)

    logp = torch.log(p + 1e-8)

    return logp


def sample_discretized_normal(mean, logvar, inverse_bin_width):
    y = torch.randn_like(mean)
    x = torch.exp(0.5 * logvar) * y + mean
    x = torch.round(x * inverse_bin_width) / inverse_bin_width
    return x


def log_mixture_discretized_logistic(x, mean, logscale, pi, inverse_bin_width):
    # x: [B, C, H, W]; mean, logscale, pi: [B, C, H, W, n_mixtures].
    scale = torch.exp(logscale)

    x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

    p = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale) \
        - torch.sigmoid((x - 0.5 / inverse_bin_width - mean) / scale)

    p = torch.sum(p * pi, dim=-1)

    logp = torch.log(p + 1e-8)

    return logp
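# Hypothetical sanity check (not part of the original module): the discretized
# logistic probabilities should sum to roughly 1 over a grid of bins, provided
# the grid covers essentially all of the probability mass. The bin width
# (1/256, as for 8-bit images rescaled to [0, 1]) and the parameter values
# below are illustrative assumptions.
def _check_discretized_logistic_normalizes():
    inverse_bin_width = 256.
    # Bin centers at multiples of 1 / inverse_bin_width on [0, 1).
    bins = torch.arange(0., 1., 1. / inverse_bin_width)
    mean = torch.full_like(bins, 0.5)
    logscale = torch.full_like(bins, math.log(0.05))
    total = torch.exp(
        log_discretized_logistic(bins, mean, logscale, inverse_bin_width)).sum()
    assert abs(total.item() - 1.) < 1e-2, total.item()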
def mixture_discretized_logistic_cdf(x, mean, logscale, pi, inverse_bin_width):
    # Mixture CDF evaluated at the right edge of the bin containing x.
    scale = torch.exp(logscale)

    x = x[..., None]

    cdfs = torch.sigmoid((x + 0.5 / inverse_bin_width - mean) / scale)

    cdf = torch.sum(cdfs * pi, dim=-1)

    return cdf


def sample_mixture_discretized_logistic(mean, logs, pi, inverse_bin_width):
    # Sample mixtures
    b, c, h, w, n_mixtures = tuple(map(int, pi.size()))
    pi = pi.view(b * c * h * w, n_mixtures)
    sampled_pi = torch.multinomial(pi, num_samples=1).view(-1)

    # Select mixture params
    mean = mean.view(b * c * h * w, n_mixtures)
    mean = mean[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)

    logs = logs.view(b * c * h * w, n_mixtures)
    logs = logs[torch.arange(b * c * h * w), sampled_pi].view(b, c, h, w)

    # Inverse-CDF sample from the selected logistics, then round to the grid.
    # Clamp u away from 0 and 1 so the logit stays finite.
    y = torch.rand_like(mean)
    y = torch.clamp(y, min=MIN_EPSILON, max=MAX_EPSILON)
    x = torch.exp(logs) * torch.log(y / (1 - y)) + mean
    x = torch.round(x * inverse_bin_width) / inverse_bin_width
    return x


def log_multinomial(logits, targets):
    # Log-probability of integer targets under a categorical distribution;
    # logits: [B, n_categories, C, H, W], targets: [B, C, H, W].
    return -F.cross_entropy(logits, targets, reduction='none')


def sample_multinomial(logits):
    b, n_categories, c, h, w = logits.size()
    logits = logits.permute(0, 2, 3, 4, 1)
    p = F.softmax(logits, dim=-1)
    p = p.view(b * c * h * w, n_categories)
    x = torch.multinomial(p, num_samples=1).view(b, c, h, w)
    return x
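# Minimal usage sketch (an assumption for illustration, not from the original
# repo): exercises the mixture-of-discretized-logistics and categorical
# helpers on random parameters with the shapes documented above. All sizes
# and values here are illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)
    _check_discretized_logistic_normalizes()

    b, c, h, w, n_mixtures = 2, 3, 4, 4, 5

    # Mixture parameters carry a trailing mixture axis.
    mean = torch.randn(b, c, h, w, n_mixtures)
    logs = torch.zeros(b, c, h, w, n_mixtures)
    pi = F.softmax(torch.randn(b, c, h, w, n_mixtures), dim=-1)

    x = sample_mixture_discretized_logistic(mean, logs, pi, inverse_bin_width=256.)
    logp = log_mixture_discretized_logistic(x, mean, logs, pi, inverse_bin_width=256.)
    print(x.size(), logp.size())  # both [2, 3, 4, 4]

    # Categorical over 256 levels per pixel.
    logits = torch.randn(b, 256, c, h, w)
    targets = sample_multinomial(logits)
    print(log_multinomial(logits, targets).size())  # [2, 3, 4, 4]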