This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../integer_discrete_flows/models/backround.py
2025-11-25 20:20:08 +01:00

151 lines
4.2 KiB
Python

import torch
import torch.nn.functional as F
import numpy as np
from models.utils import Base
class RoundStraightThrough(torch.autograd.Function):
def __init__(self):
super().__init__()
@staticmethod
def forward(ctx, input):
rounded = torch.round(input, out=None)
return rounded
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
return grad_input
_round_straightthrough = RoundStraightThrough().apply
def _stacked_sigmoid(x, temperature, n_approx=3):
x_ = x - 0.5
rounded = torch.round(x_)
x_remainder = x_ - rounded
size = x_.size()
x_remainder = x_remainder.view(size + (1,))
translation = torch.arange(n_approx) - n_approx // 2
translation = translation.to(device=x.DEVICE, dtype=x.dtype)
translation = translation.view([1] * len(size) + [len(translation)])
out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)
return out + rounded - (n_approx // 2)
class SmoothRound(Base):
    """Differentiable rounding based on the stacked-sigmoid approximation.

    `temperature` must be assigned before forward() is called; the setter
    also chooses how many sigmoids are stacked (fewer suffice at low
    temperatures). `hard_round` must likewise be set by the owner.
    """

    def __init__(self):
        # Call the parent constructor first, consistent with the other
        # modules in this file, then declare the configuration attributes.
        super().__init__()
        self._temperature = None  # sigmoid sharpness, set via property
        self._n_approx = None     # number of stacked sigmoids, derived below
        self.hard_round = None    # if True, snap output to integers

    @property
    def temperature(self):
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        self._temperature = value
        # Lower temperatures make each sigmoid sharper, so fewer
        # translated copies are needed for a good approximation.
        if self._temperature <= 0.05:
            self._n_approx = 1
        elif 0.05 < self._temperature < 0.13:
            self._n_approx = 3
        else:
            self._n_approx = 5

    def forward(self, x):
        # Fail loudly if the module was never configured.
        assert self._temperature is not None
        assert self._n_approx is not None
        assert self.hard_round is not None

        if self.temperature <= 0.25:
            h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx)
        else:
            # Above this temperature the smooth approximation is close to
            # the identity, so skip it entirely.
            h = x

        if self.hard_round:
            h = _round_straightthrough(h)
        return h
class StochasticRound(Base):
    """Stochastic rounding: add uniform noise in [-0.5, 0.5) to x.

    In expectation the output equals x; with hard_round=True the noisy
    value is snapped to an integer via the straight-through estimator.
    """

    def __init__(self):
        super().__init__()
        self.hard_round = None  # must be set by the owner before forward()

    def forward(self, x):
        # Consistent with SmoothRound.forward: fail loudly when the module
        # is unconfigured instead of silently treating None as False.
        assert self.hard_round is not None

        u = torch.rand_like(x)
        h = x + u - 0.5
        if self.hard_round:
            h = _round_straightthrough(h)
        return h
class BackRound(Base):
    def __init__(self, args, inverse_bin_width):
        """
        BackRound is an approximation to Round that allows for backpropagation.
        Approximate the round function using a sum of translated sigmoids.
        The temperature determines how well the round function is approximated,
        i.e., a lower temperature corresponds to a better approximation, at
        the cost of more vanishing gradients.

        BackRound supports the following settings (temperature and hard are
        attributes of the underlying round module, set by its owner):
        * By setting hard to True and temperature > 0.25, BackRound
          reduces to a round function with a straight-through gradient
          estimator.
        * When using 0 < temperature <= 0.25 and hard = True, the
          output in the forward pass is equivalent to a round function, but the
          gradient is approximated by the gradient of a sum of sigmoids.
        * When using hard = False, the output is not constrained to integers.
        * When temperature > 0.25 and hard = False, BackRound reduces to
          the identity function.

        Arguments
        ---------
        args:
            Namespace with attribute `round_approx`, either 'smooth' or
            'stochastic', selecting the rounding approximation.
        inverse_bin_width: float
            Inputs are scaled by this factor before rounding and divided by
            it afterwards, so rounding happens on a grid of width
            1 / inverse_bin_width.

        Raises
        ------
        ValueError
            If args.round_approx is not a known approximation.
        """
        super().__init__()
        self.inverse_bin_width = inverse_bin_width
        self.round_approx = args.round_approx

        if args.round_approx == 'smooth':
            self.round = SmoothRound()
        elif args.round_approx == 'stochastic':
            self.round = StochasticRound()
        else:
            # Include the offending value in the message for debuggability.
            raise ValueError(
                "unknown round_approx: {!r}".format(args.round_approx))

    def forward(self, x):
        if self.round_approx == 'smooth' or self.round_approx == 'stochastic':
            # Round on a grid of width 1/inverse_bin_width: scale up,
            # round, scale back down.
            h = x * self.inverse_bin_width
            h = self.round(h)
            return h / self.inverse_bin_width
        else:
            # Defensive: unreachable unless round_approx was mutated after
            # construction.
            raise ValueError(
                "unknown round_approx: {!r}".format(self.round_approx))