feat: initial for IDF
This commit is contained in:
commit
ef4684ef39
27 changed files with 2830 additions and 0 deletions
174
integer_discrete_flows/optimization/training.py
Normal file
174
integer_discrete_flows/optimization/training.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
from __future__ import print_function
|
||||
import torch
|
||||
|
||||
from optimization.loss import calculate_loss
|
||||
from utils.visual_evaluation import plot_reconstructions
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def train(epoch, train_loader, model, opt, args):
|
||||
model.train()
|
||||
train_loss = np.zeros(len(train_loader))
|
||||
train_bpd = np.zeros(len(train_loader))
|
||||
|
||||
num_data = 0
|
||||
|
||||
for batch_idx, (data, _) in enumerate(train_loader):
|
||||
data = data.view(-1, *args.input_size)
|
||||
|
||||
data = data.to(args.device)
|
||||
|
||||
opt.zero_grad()
|
||||
loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data)
|
||||
|
||||
loss = torch.mean(loss)
|
||||
bpd = torch.mean(bpd)
|
||||
bpd_per_prior = [torch.mean(i) for i in bpd_per_prior]
|
||||
|
||||
loss.backward()
|
||||
loss = loss.item()
|
||||
train_loss[batch_idx] = loss
|
||||
train_bpd[batch_idx] = bpd
|
||||
|
||||
ldj = torch.mean(ldj).item() / np.prod(args.input_size) / np.log(2)
|
||||
|
||||
opt.step()
|
||||
|
||||
num_data += len(data)
|
||||
|
||||
if batch_idx % args.log_interval == 0:
|
||||
perc = 100. * batch_idx / len(train_loader)
|
||||
|
||||
tmp = 'Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] \tLoss: {:11.6f}\tbpd: {:8.6f}\tbits ldj: {:8.6f}'
|
||||
print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, loss, bpd, ldj))
|
||||
|
||||
print('z min: {:8.3f}, max: {:8.3f}'.format(torch.min(z).item() * 256, torch.max(z).item() * 256))
|
||||
|
||||
print('z bpd: {:.3f}'.format(bpd_per_prior[0]))
|
||||
for i in range(1, len(bpd_per_prior)):
|
||||
print('y{} bpd: {:.3f}'.format(i-1, bpd_per_prior[i]))
|
||||
|
||||
print('pz mu', np.mean(pz[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
print('pz logs ', np.mean(pz[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
if len(pz) == 3:
|
||||
print('pz pi ', np.mean(pz[2].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
|
||||
for i, py in enumerate(pys):
|
||||
print('py{} mu '.format(i), np.mean(py[0].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
print('py{} logs '.format(i), np.mean(py[1].data.cpu().numpy(), axis=(0, 1, 2, 3)))
|
||||
|
||||
from utils.visual_evaluation import plot_images
|
||||
import os
|
||||
if not os.path.exists(args.snap_dir + 'training/'):
|
||||
os.makedirs(args.snap_dir + 'training/')
|
||||
|
||||
print('====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'.format(
|
||||
epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader)))
|
||||
|
||||
return train_loss, train_bpd
|
||||
|
||||
|
||||
def evaluate(train_loader, val_loader, model, model_sample, args, testing=False, file=None, epoch=0):
|
||||
model.eval()
|
||||
|
||||
loss_type = 'bpd'
|
||||
|
||||
def analyse(data_loader, plot=False):
|
||||
bpds = []
|
||||
batch_idx = 0
|
||||
with torch.no_grad():
|
||||
for data, _ in data_loader:
|
||||
batch_idx += 1
|
||||
|
||||
if args.cuda:
|
||||
data = data.cuda()
|
||||
|
||||
data = data.view(-1, *args.input_size)
|
||||
|
||||
loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj = \
|
||||
model(data)
|
||||
loss = torch.mean(loss).item()
|
||||
batch_bpd = torch.mean(batch_bpd).item()
|
||||
|
||||
bpds.append(batch_bpd)
|
||||
|
||||
bpd = np.mean(bpds)
|
||||
|
||||
with torch.no_grad():
|
||||
if not testing and plot:
|
||||
x_sample = model_sample.sample(n=100)
|
||||
|
||||
try:
|
||||
plot_reconstructions(
|
||||
x_sample, bpd, loss_type, epoch, args)
|
||||
except:
|
||||
print('Not plotting')
|
||||
|
||||
return bpd
|
||||
|
||||
bpd_train = analyse(train_loader)
|
||||
bpd_val = analyse(val_loader, plot=True)
|
||||
|
||||
with open(file, 'a') as ff:
|
||||
msg = 'epoch {}\ttrain bpd {:.3f}\tval bpd {:.3f}\t'.format(
|
||||
epoch,
|
||||
bpd_train,
|
||||
bpd_val)
|
||||
print(msg, file=ff)
|
||||
|
||||
loss = bpd_val * np.prod(args.input_size) * np.log(2.)
|
||||
bpd = bpd_val
|
||||
|
||||
file = None
|
||||
|
||||
# Compute log-likelihood
|
||||
with torch.no_grad():
|
||||
if testing:
|
||||
test_data = val_loader.dataset.data_tensor
|
||||
|
||||
if args.cuda:
|
||||
test_data = test_data.cuda()
|
||||
|
||||
print('Computing log-likelihood on test set')
|
||||
|
||||
model.eval()
|
||||
|
||||
log_likelihood = analyse(test_data)
|
||||
|
||||
else:
|
||||
log_likelihood = None
|
||||
nll_bpd = None
|
||||
|
||||
if file is None:
|
||||
if testing:
|
||||
print('====> Test set loss: {:.4f}'.format(loss))
|
||||
print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood))
|
||||
|
||||
print('====> Test set bpd (elbo): {:.4f}'.format(bpd))
|
||||
print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood/
|
||||
(np.prod(args.input_size) * np.log(2.))))
|
||||
|
||||
else:
|
||||
print('====> Validation set loss: {:.4f}'.format(loss))
|
||||
print('====> Validation set bpd: {:.4f}'.format(bpd))
|
||||
else:
|
||||
with open(file, 'a') as ff:
|
||||
if testing:
|
||||
print('====> Test set loss: {:.4f}'.format(loss), file=ff)
|
||||
print('====> Test set log-likelihood: {:.4f}'.format(log_likelihood), file=ff)
|
||||
|
||||
print('====> Test set bpd: {:.4f}'.format(bpd), file=ff)
|
||||
print('====> Test set bpd (log-likelihood): {:.4f}'.format(log_likelihood /
|
||||
(np.prod(args.input_size) * np.log(2.))),
|
||||
file=ff)
|
||||
|
||||
else:
|
||||
print('====> Validation set loss: {:.4f}'.format(loss), file=ff)
|
||||
print('====> Validation set bpd: {:.4f}'.format(loss / (np.prod(args.input_size) * np.log(2.))),
|
||||
file=ff)
|
||||
|
||||
if not testing:
|
||||
return loss, bpd
|
||||
else:
|
||||
return log_likelihood, nll_bpd
|
||||
Reference in a new issue