feat: uhm, i changed some things
This commit is contained in:
parent
b58682cb49
commit
6de4db24cc
27 changed files with 1302 additions and 137 deletions
|
|
@ -36,7 +36,7 @@ def CDF_fn(pz, bin_width, variable_type, distribution_type):
|
|||
|
||||
bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
|
||||
bin_locations = bin_locations.float() * bin_width
|
||||
bin_locations = bin_locations.to(device=pz[0].device)
|
||||
bin_locations = bin_locations.to(device=pz[0].DEVICE)
|
||||
|
||||
pz = [param[:, :, :, :, None] for param in pz]
|
||||
cdf = cdf_fn(
|
||||
|
|
@ -86,7 +86,7 @@ def decode_sample(
|
|||
state, pz, variable_type, distribution_type, bin_width=1./256):
|
||||
state = rans.unflatten(state)
|
||||
|
||||
device = pz[0].device
|
||||
device = pz[0].DEVICE
|
||||
size = pz[0].size()[0:4]
|
||||
|
||||
CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)
|
||||
|
|
|
|||
Reference in a new issue