feat: initial for IDF

Robin Meersman 2025-11-07 12:54:36 +01:00
commit ef4684ef39
27 changed files with 2830 additions and 0 deletions


@@ -0,0 +1,148 @@
from __future__ import print_function

import numpy as np
import torch

from utils.distributions import log_discretized_logistic, \
    log_mixture_discretized_logistic, log_normal, log_discretized_normal, \
    log_logistic, log_mixture_normal
from models.backround import _round_straightthrough

def compute_log_ps(pxs, xs, args):
    # Add likelihoods of intermediate representations.
    inverse_bin_width = 2.**args.n_bits

    log_pxs = []
    for px, x in zip(pxs, xs):
        if args.variable_type == 'discrete':
            if args.distribution_type == 'logistic':
                log_px = log_discretized_logistic(
                    x, *px, inverse_bin_width=inverse_bin_width)
            elif args.distribution_type == 'normal':
                log_px = log_discretized_normal(
                    x, *px, inverse_bin_width=inverse_bin_width)
        elif args.variable_type == 'continuous':
            if args.distribution_type == 'logistic':
                log_px = log_logistic(x, *px)
            elif args.distribution_type == 'normal':
                log_px = log_normal(x, *px)
            elif args.distribution_type == 'steplogistic':
                # Round to the grid with a straight-through gradient estimator.
                x = _round_straightthrough(x * inverse_bin_width) / inverse_bin_width
                log_px = log_discretized_logistic(
                    x, *px, inverse_bin_width=inverse_bin_width)

        # Sum elementwise log-likelihoods over channel and spatial dims.
        log_pxs.append(
            torch.sum(log_px, dim=[1, 2, 3]))

    return log_pxs

def compute_log_pz(pz, z, args):
    inverse_bin_width = 2.**args.n_bits

    if args.variable_type == 'discrete':
        if args.distribution_type == 'logistic':
            if args.n_mixtures == 1:
                log_pz = log_discretized_logistic(
                    z, pz[0], pz[1], inverse_bin_width=inverse_bin_width)
            else:
                log_pz = log_mixture_discretized_logistic(
                    z, pz[0], pz[1], pz[2],
                    inverse_bin_width=inverse_bin_width)
        elif args.distribution_type == 'normal':
            log_pz = log_discretized_normal(
                z, *pz, inverse_bin_width=inverse_bin_width)
    elif args.variable_type == 'continuous':
        if args.distribution_type == 'logistic':
            log_pz = log_logistic(z, *pz)
        elif args.distribution_type == 'normal':
            if args.n_mixtures == 1:
                log_pz = log_normal(z, *pz)
            else:
                log_pz = log_mixture_normal(z, *pz)
        elif args.distribution_type == 'steplogistic':
            # Round to the grid (straight-through); the bin width follows
            # n_bits, mirroring compute_log_ps above.
            z = _round_straightthrough(z * inverse_bin_width) / inverse_bin_width
            log_pz = log_discretized_logistic(
                z, *pz, inverse_bin_width=inverse_bin_width)

    log_pz = torch.sum(log_pz, dim=[1, 2, 3])

    return log_pz
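
# Example (assuming the common 8-bit setting): with args.n_bits = 8 the
# inverse bin width is 2**8 = 256, i.e. the discretized likelihoods above
# evaluate probability mass on a grid with spacing 1/256.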

def compute_loss_function(pz, z, pys, ys, ldj, args):
    """
    Computes the loss, averaged over the batch dimension.
    :param pz: parameters of the prior over the final representation z
    :param z: final representation
    :param pys: parameters of the priors over intermediate representations
    :param ys: intermediate representations
    :param ldj: log-determinant of the Jacobian
    :param args: global parameter settings
    :return: loss, bpd, bpd_per_prior
    """
    # Get per-example losses, then average over the batch.
    loss_array, bpd_array, bpd_per_prior_array = \
        compute_loss_array(pz, z, pys, ys, ldj, args)

    loss = torch.mean(loss_array)
    bpd = torch.mean(bpd_array).item()
    bpd_per_prior = [torch.mean(x) for x in bpd_per_prior_array]

    return loss, bpd, bpd_per_prior

def convert_bpd(log_p, input_size):
    # Negative log-likelihood in nats, converted to bits per dimension.
    return -log_p / (np.prod(input_size) * np.log(2.))
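
# Worked example (illustrative numbers, not from this commit): with
# input_size = (3, 32, 32) there are 3072 dimensions, so a summed
# log-likelihood of log_p = -6000 nats converts to
#     -(-6000) / (3072 * log(2)) ≈ 2.82 bits per dimension.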

def compute_loss_array(pz, z, pys, ys, ldj, args):
    """
    Computes the per-example loss, without reduction over the batch dimension.
    :param pz: parameters of the prior over the final representation z
    :param z: final representation
    :param pys: parameters of the priors over intermediate representations
    :param ys: intermediate representations
    :param ldj: log-determinant of the Jacobian
    :param args: global parameter settings
    :return: loss, bpd, bpd_per_prior
    """
    bpd_per_prior = []

    # Likelihood of the final representation.
    log_pz = compute_log_pz(pz, z, args)
    bpd_per_prior.append(convert_bpd(log_pz.detach(), args.input_size))
    log_p = log_pz

    # Add likelihoods of intermediate representations.
    if ys:
        log_pys = compute_log_ps(pys, ys, args)
        for log_py in log_pys:
            log_p += log_py
            bpd_per_prior.append(convert_bpd(log_py.detach(), args.input_size))

    # Change-of-variables correction.
    log_p += ldj

    loss = -log_p
    bpd = convert_bpd(log_p.detach(), args.input_size)

    return loss, bpd, bpd_per_prior

def calculate_loss(pz, z, pys, ys, ldj, args):
    return compute_loss_function(pz, z, pys, ys, ldj, args)
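
A minimal sketch of how the loss entry point can be exercised in isolation, assuming a flow whose forward pass has produced prior parameters, a latent, and a log-det-Jacobian. The Args fields and tensor shapes below are illustrative, not taken from this commit, and the (mean, log-scale) parameterization is inferred from how the utils.distributions helpers are called above:

import torch
from optimization.loss import calculate_loss

class Args:
    # Only the fields the loss module reads.
    variable_type = 'continuous'
    distribution_type = 'normal'
    n_bits = 8
    n_mixtures = 1
    input_size = (3, 32, 32)

# Illustrative shapes: batch of 8, final latent z of shape (8, 12, 8, 8).
pz = (torch.zeros(8, 12, 8, 8), torch.zeros(8, 12, 8, 8))  # mean, log-scale
z = torch.randn(8, 12, 8, 8)
ldj = torch.zeros(8)  # no intermediate representations: pys = ys = []

loss, bpd, bpd_per_prior = calculate_loss(pz, z, [], [], ldj, Args())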


@@ -0,0 +1,174 @@
from __future__ import print_function

import os

import numpy as np
import torch

from optimization.loss import calculate_loss
from utils.visual_evaluation import plot_reconstructions

def train(epoch, train_loader, model, opt, args):
    model.train()

    train_loss = np.zeros(len(train_loader))
    train_bpd = np.zeros(len(train_loader))

    num_data = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, *args.input_size)
        data = data.to(args.device)

        opt.zero_grad()
        loss, bpd, bpd_per_prior, pz, z, pys, ys, ldj = model(data)

        # Reduce per-example quantities over the batch.
        loss = torch.mean(loss)
        bpd = torch.mean(bpd).item()
        bpd_per_prior = [torch.mean(i).item() for i in bpd_per_prior]

        loss.backward()

        loss = loss.item()
        train_loss[batch_idx] = loss
        train_bpd[batch_idx] = bpd

        # Log-det-Jacobian expressed in bits per dimension.
        ldj = torch.mean(ldj).item() / np.prod(args.input_size) / np.log(2)

        opt.step()

        num_data += len(data)

        if batch_idx % args.log_interval == 0:
            perc = 100. * batch_idx / len(train_loader)

            tmp = 'Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] \tLoss: {:11.6f}\tbpd: {:8.6f}\tbits ldj: {:8.6f}'
            print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, loss, bpd, ldj))
            print('z min: {:8.3f}, max: {:8.3f}'.format(
                torch.min(z).item() * 256, torch.max(z).item() * 256))

            print('z bpd: {:.3f}'.format(bpd_per_prior[0]))
            for i in range(1, len(bpd_per_prior)):
                print('y{} bpd: {:.3f}'.format(i - 1, bpd_per_prior[i]))

            print('pz mu', np.mean(pz[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
            print('pz logs ', np.mean(pz[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
            if len(pz) == 3:
                print('pz pi ', np.mean(pz[2].data.cpu().numpy(), axis=(0, 1, 2, 3)))

            for i, py in enumerate(pys):
                print('py{} mu '.format(i),
                      np.mean(py[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
                print('py{} logs '.format(i),
                      np.mean(py[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))

            # Make sure the directory for training plots exists.
            if not os.path.exists(args.snap_dir + 'training/'):
                os.makedirs(args.snap_dir + 'training/')

    print('====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'.format(
        epoch, train_loss.sum() / len(train_loader),
        train_bpd.sum() / len(train_loader)))

    return train_loss, train_bpd
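
# Usage sketch (assumed driver, e.g. a main.py not included in this commit;
# `args.learning_rate` is an illustrative name):
#
#     opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
#     for epoch in range(1, args.epochs + 1):
#         train_loss, train_bpd = train(epoch, train_loader, model, opt, args)
#         val_loss, val_bpd = evaluate(train_loader, val_loader, model,
#                                      model_sample, args,
#                                      file=args.snap_dir + 'log.txt',
#                                      epoch=epoch)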

def evaluate(train_loader, val_loader, model, model_sample, args,
             testing=False, file=None, epoch=0):
    model.eval()

    loss_type = 'bpd'

    def analyse(data_loader, plot=False):
        bpds = []

        with torch.no_grad():
            for data, _ in data_loader:
                data = data.to(args.device)
                data = data.view(-1, *args.input_size)

                loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj = \
                    model(data)
                loss = torch.mean(loss).item()
                batch_bpd = torch.mean(batch_bpd).item()

                bpds.append(batch_bpd)

        bpd = np.mean(bpds)

        # Optionally plot samples from the model.
        with torch.no_grad():
            if not testing and plot:
                x_sample = model_sample.sample(n=100)
                try:
                    plot_reconstructions(
                        x_sample, bpd, loss_type, epoch, args)
                except Exception:
                    print('Not plotting')

        return bpd
    bpd_train = analyse(train_loader)
    bpd_val = analyse(val_loader, plot=True)

    if file is not None:
        with open(file, 'a') as ff:
            msg = 'epoch {}\ttrain bpd {:.3f}\tval bpd {:.3f}'.format(
                epoch,
                bpd_train,
                bpd_val)
            print(msg, file=ff)

    loss = bpd_val * np.prod(args.input_size) * np.log(2.)
    bpd = bpd_val

    # Compute the test-set log-likelihood (in nats). When testing,
    # val_loader holds the test set, so the bpd computed above is reused.
    if testing:
        print('Computing log-likelihood on test set')
        nll_bpd = bpd_val
        log_likelihood = -loss
    else:
        log_likelihood = None
        nll_bpd = None
    if file is None:
        if testing:
            print('====> Test set loss: {:.4f}'.format(loss))
            print('====> Test set log-likelihood: {:.4f}'.format(
                log_likelihood))
            print('====> Test set bpd (elbo): {:.4f}'.format(bpd))
            print('====> Test set bpd (log-likelihood): {:.4f}'.format(
                -log_likelihood / (np.prod(args.input_size) * np.log(2.))))
        else:
            print('====> Validation set loss: {:.4f}'.format(loss))
            print('====> Validation set bpd: {:.4f}'.format(bpd))
    else:
        with open(file, 'a') as ff:
            if testing:
                print('====> Test set loss: {:.4f}'.format(loss), file=ff)
                print('====> Test set log-likelihood: {:.4f}'.format(
                    log_likelihood), file=ff)
                print('====> Test set bpd: {:.4f}'.format(bpd), file=ff)
                print('====> Test set bpd (log-likelihood): {:.4f}'.format(
                    -log_likelihood / (np.prod(args.input_size) * np.log(2.))),
                    file=ff)
            else:
                print('====> Validation set loss: {:.4f}'.format(loss),
                      file=ff)
                print('====> Validation set bpd: {:.4f}'.format(bpd),
                      file=ff)
    if not testing:
        return loss, bpd
    else:
        return log_likelihood, nll_bpd
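
For the final test pass, the same entry point can be reused. A hedged sketch, assuming the driver passes the test loader in the val_loader slot (as the testing branch above expects; `test_loader` and `args.epochs` are illustrative names):

test_log_likelihood, test_bpd = evaluate(
    train_loader, test_loader, model, model_sample, args,
    testing=True, epoch=args.epochs)

Here test_log_likelihood is in nats and test_bpd in bits per dimension, matching the return values of evaluate when testing=True.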