feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -36,7 +36,7 @@ def CDF_fn(pz, bin_width, variable_type, distribution_type):
bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
bin_locations = bin_locations.float() * bin_width
bin_locations = bin_locations.to(device=pz[0].device)
bin_locations = bin_locations.to(device=pz[0].DEVICE)
pz = [param[:, :, :, :, None] for param in pz]
cdf = cdf_fn(
@ -86,7 +86,7 @@ def decode_sample(
state, pz, variable_type, distribution_type, bin_width=1./256):
state = rans.unflatten(state)
device = pz[0].device
device = pz[0].DEVICE
size = pz[0].size()[0:4]
CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)

View file

@ -190,7 +190,7 @@ def run(args, kwargs):
import models.Model as Model
model = Model.Model(args)
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.set_temperature(args.temperature)
model.enable_hard_round(args.hard_round)
@ -208,7 +208,7 @@ def run(args, kwargs):
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = torch.nn.DataParallel(model, dim=0)
model.to(args.device)
model.to(args.DEVICE)
def lr_lambda(epoch):
return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)

View file

@ -34,7 +34,7 @@ def _stacked_sigmoid(x, temperature, n_approx=3):
x_remainder = x_remainder.view(size + (1,))
translation = torch.arange(n_approx) - n_approx // 2
translation = translation.to(device=x.device, dtype=x.dtype)
translation = translation.to(device=x.DEVICE, dtype=x.dtype)
translation = translation.view([1] * len(size) + [len(translation)])
out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)

View file

@ -17,7 +17,7 @@ def train(epoch, train_loader, model, opt, args):
for batch_idx, (data, _) in enumerate(train_loader):
data = data.view(-1, *args.input_size)
data = data.to(args.device)
data = data.to(args.DEVICE)
opt.zero_grad()
loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data)