feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -34,7 +34,7 @@ def _stacked_sigmoid(x, temperature, n_approx=3):
x_remainder = x_remainder.view(size + (1,))
translation = torch.arange(n_approx) - n_approx // 2
translation = translation.to(device=x.device, dtype=x.dtype)
translation = translation.to(device=x.DEVICE, dtype=x.dtype)
translation = translation.view([1] * len(size) + [len(translation)])
out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)