feat: uhm, i changed some things
This commit is contained in:
parent
b58682cb49
commit
6de4db24cc
27 changed files with 1302 additions and 137 deletions
|
|
@ -36,7 +36,7 @@ def CDF_fn(pz, bin_width, variable_type, distribution_type):
|
|||
|
||||
bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
|
||||
bin_locations = bin_locations.float() * bin_width
|
||||
bin_locations = bin_locations.to(device=pz[0].device)
|
||||
bin_locations = bin_locations.to(device=pz[0].DEVICE)
|
||||
|
||||
pz = [param[:, :, :, :, None] for param in pz]
|
||||
cdf = cdf_fn(
|
||||
|
|
@ -86,7 +86,7 @@ def decode_sample(
|
|||
state, pz, variable_type, distribution_type, bin_width=1./256):
|
||||
state = rans.unflatten(state)
|
||||
|
||||
device = pz[0].device
|
||||
device = pz[0].DEVICE
|
||||
size = pz[0].size()[0:4]
|
||||
|
||||
CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)
|
||||
|
|
|
|||
|
|
@ -190,7 +190,7 @@ def run(args, kwargs):
|
|||
import models.Model as Model
|
||||
|
||||
model = Model.Model(args)
|
||||
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
args.DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
model.set_temperature(args.temperature)
|
||||
model.enable_hard_round(args.hard_round)
|
||||
|
||||
|
|
@ -208,7 +208,7 @@ def run(args, kwargs):
|
|||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||
model = torch.nn.DataParallel(model, dim=0)
|
||||
|
||||
model.to(args.device)
|
||||
model.to(args.DEVICE)
|
||||
|
||||
def lr_lambda(epoch):
|
||||
return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def _stacked_sigmoid(x, temperature, n_approx=3):
|
|||
x_remainder = x_remainder.view(size + (1,))
|
||||
|
||||
translation = torch.arange(n_approx) - n_approx // 2
|
||||
translation = translation.to(device=x.device, dtype=x.dtype)
|
||||
translation = translation.to(device=x.DEVICE, dtype=x.dtype)
|
||||
translation = translation.view([1] * len(size) + [len(translation)])
|
||||
out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ def train(epoch, train_loader, model, opt, args):
|
|||
for batch_idx, (data, _) in enumerate(train_loader):
|
||||
data = data.view(-1, *args.input_size)
|
||||
|
||||
data = data.to(args.device)
|
||||
data = data.to(args.DEVICE)
|
||||
|
||||
opt.zero_grad()
|
||||
loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data)
|
||||
|
|
|
|||
Reference in a new issue