feat: uhm, i changed some things
This commit is contained in:
parent
b58682cb49
commit
6de4db24cc
27 changed files with 1302 additions and 137 deletions
|
|
@ -34,7 +34,7 @@ def _stacked_sigmoid(x, temperature, n_approx=3):
|
|||
x_remainder = x_remainder.view(size + (1,))
|
||||
|
||||
translation = torch.arange(n_approx) - n_approx // 2
|
||||
translation = translation.to(device=x.device, dtype=x.dtype)
|
||||
translation = translation.to(device=x.DEVICE, dtype=x.dtype)
|
||||
translation = translation.view([1] * len(size) + [len(translation)])
|
||||
out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)
|
||||
|
||||
|
|
|
|||
Reference in a new issue