import torch
import torch.nn.functional as F
import numpy as np

from models.utils import Base


class RoundStraightThrough(torch.autograd.Function):
    """Round to the nearest integer, with a straight-through gradient estimator."""

    @staticmethod
    def forward(ctx, input):
        rounded = torch.round(input)
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the incoming gradient through unchanged.
        grad_input = grad_output.clone()
        return grad_input


_round_straightthrough = RoundStraightThrough.apply


def _stacked_sigmoid(x, temperature, n_approx=3):
    """Approximate round(x) with a sum of n_approx translated sigmoids."""
    x_ = x - 0.5
    rounded = torch.round(x_)
    x_remainder = x_ - rounded

    size = x_.size()
    x_remainder = x_remainder.view(size + (1,))

    # Translations of the individual sigmoids, e.g. [-1, 0, 1] for n_approx = 3.
    translation = torch.arange(n_approx) - n_approx // 2
    translation = translation.to(device=x.device, dtype=x.dtype)
    translation = translation.view([1] * len(size) + [len(translation)])
    out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)

    return out + rounded - (n_approx // 2)

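# Illustrative sanity check (an addition, not part of the original module):
# for a small temperature the stacked sigmoid tracks torch.round closely for
# inputs away from the halfway points, e.g.
#
#     _stacked_sigmoid(torch.tensor([0.2, 1.7]), temperature=0.05)
#
# returns values close to torch.round(x), roughly [0.0, 2.0] here.
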
class SmoothRound(Base):
    def __init__(self):
        super().__init__()
        self._temperature = None
        self._n_approx = None
        self.hard_round = None

    @property
    def temperature(self):
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        self._temperature = value

        # Lower temperatures need fewer translated sigmoids for a good fit.
        if self._temperature <= 0.05:
            self._n_approx = 1
        elif 0.05 < self._temperature < 0.13:
            self._n_approx = 3
        else:
            self._n_approx = 5

    def forward(self, x):
        assert self._temperature is not None
        assert self._n_approx is not None
        assert self.hard_round is not None

        # Above temperature 0.25 the stacked sigmoid is effectively the
        # identity, so it is skipped.
        if self.temperature <= 0.25:
            h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx)
        else:
            h = x

        if self.hard_round:
            h = _round_straightthrough(h)

        return h


class StochasticRound(Base):
    def __init__(self):
        super().__init__()
        self.hard_round = None

    def forward(self, x):
        # Additive uniform noise in [-0.5, 0.5): stochastic rounding.
        u = torch.rand_like(x)

        h = x + u - 0.5

        if self.hard_round:
            h = _round_straightthrough(h)

        return h


class BackRound(Base):

    def __init__(self, args, inverse_bin_width):
        """
        BackRound is an approximation to Round that allows for backpropagation.

        Approximate the round function using a sum of translated sigmoids.
        The temperature determines how well the round function is approximated,
        i.e., a lower temperature corresponds to a better approximation, at
        the cost of more vanishing gradients.

        BackRound supports the following settings:
        * By setting hard to True and temperature > 0.25, BackRound
          reduces to a round function with a straight-through gradient
          estimator.
        * When using 0 < temperature <= 0.25 and hard = True, the
          output in the forward pass is equivalent to a round function, but the
          gradient is approximated by the gradient of a sum of sigmoids.
        * When using hard = False, the output is not constrained to integers.
        * When temperature > 0.25 and hard = False, BackRound reduces to
          the identity function.

        Arguments
        ---------
        temperature: float
            Temperature used for the stacked sigmoid approximation. If the
            temperature is greater than 0.25, the approximation reduces to
            the identity function.
        hard: bool
            If hard is True, a (hard) round is applied before returning. The
            gradient for this is approximated using the straight-through
            estimator.
        """
        super().__init__()
        self.inverse_bin_width = inverse_bin_width
        self.round_approx = args.round_approx

        if args.round_approx == 'smooth':
            self.round = SmoothRound()
        elif args.round_approx == 'stochastic':
            self.round = StochasticRound()
        else:
            raise ValueError(
                'Unknown rounding approximation: {}'.format(args.round_approx))

    def forward(self, x):
        if self.round_approx == 'smooth' or self.round_approx == 'stochastic':
            # Scale to the integer grid, round, then scale back.
            h = x * self.inverse_bin_width

            h = self.round(h)

            return h / self.inverse_bin_width

        else:
            raise ValueError(
                'Unknown rounding approximation: {}'.format(self.round_approx))
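

# Minimal usage sketch (an addition, not part of the original module). It
# assumes `args` is any object exposing a `round_approx` attribute, that
# Base behaves like torch.nn.Module (the code above already calls modules
# directly), and that temperature / hard_round are configured on the wrapped
# rounding module before the first forward pass, as the asserts in
# SmoothRound.forward expect.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(round_approx='smooth')
    back_round = BackRound(args, inverse_bin_width=256)
    back_round.round.temperature = 0.1   # triggers the n_approx selection
    back_round.round.hard_round = True   # hard round, straight-through gradients

    x = torch.randn(4, 3, requires_grad=True)
    y = back_round(x)
    y.sum().backward()
    print(y)
    print(x.grad)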