feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -36,7 +36,7 @@ def CDF_fn(pz, bin_width, variable_type, distribution_type):
bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
bin_locations = bin_locations.float() * bin_width
bin_locations = bin_locations.to(device=pz[0].device)
bin_locations = bin_locations.to(device=pz[0].DEVICE)
pz = [param[:, :, :, :, None] for param in pz]
cdf = cdf_fn(
@ -86,7 +86,7 @@ def decode_sample(
state, pz, variable_type, distribution_type, bin_width=1./256):
state = rans.unflatten(state)
device = pz[0].device
device = pz[0].DEVICE
size = pz[0].size()[0:4]
CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)