feat: uhm, i changed some things
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

@@ -176,9 +176,9 @@ class RelMultiHeadAttn(nn.Module):
    def _shift(self, x, qlen, klen, mask, left=False):
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)

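Not part of the diff, but a quick aside on the `device=x.device, dtype=x.dtype` idiom that recurs through these hunks: newly allocated buffers inherit the source tensor's placement and precision, so the same code runs on CPU or GPU and under fp16. A minimal standalone sketch with toy shapes of my own choosing, unrelated to the real tensors:

import torch

x = torch.randn(2, 3, 4, 5, dtype=torch.float16)          # stand-in for an attention tensor
zero_pad = torch.zeros((x.size(0), 1, x.size(2), x.size(3)),
                       device=x.device, dtype=x.dtype)     # same placement and precision as x
print(zero_pad.shape, zero_pad.dtype)                      # torch.Size([2, 1, 4, 5]) torch.float16
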
@@ -193,7 +193,7 @@ class RelMultiHeadAttn(nn.Module):

    def _rel_shift(self, x, zero_triu=False):
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

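For readers unfamiliar with `_rel_shift`, here is a minimal 2-D sketch of the same zero-pad/reshape trick (toy sizes, my own function name, not the repo's code): row i of the output is row i of the input shifted left by qlen - 1 - i, and the positions that run off the end hold junk that the attention mask is expected to hide.

import torch

def rel_shift_2d(x):
    # 2-D sketch of the pad-then-reinterpret trick.
    qlen, klen = x.size()
    zero_pad = torch.zeros((qlen, 1), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)   # (qlen, klen + 1)
    x_padded = x_padded.view(klen + 1, qlen)     # reinterpret the same buffer
    return x_padded[1:].reshape(qlen, klen)      # drop one row, fold back to (qlen, klen)

scores = torch.arange(12.).view(3, 4)
print(rel_shift_2d(scores))
# tensor([[ 2.,  3.,  0.,  4.],
#         [ 5.,  6.,  7.,  0.],
#         [ 8.,  9., 10., 11.]])
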
@@ -661,7 +661,7 @@ class MemTransformerLM(nn.Module):

        hids = []
        if self.attn_type == 0: # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)

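As a side note on the clamp_len logic above: the position sequence counts relative distance from the newest token down to zero, and clamping caps how far back distinct positional embeddings are used. A tiny standalone sketch with made-up numbers:

import torch

klen, clamp_len = 8, 4
pos_seq = torch.arange(klen - 1, -1, -1.0)   # distances, newest token last
print(pos_seq)                               # tensor([7., 6., 5., 4., 3., 2., 1., 0.])
pos_seq.clamp_(max=clamp_len)                # everything older than clamp_len shares one distance
print(pos_seq)                               # tensor([4., 4., 4., 4., 3., 2., 1., 0.])
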
@@ -691,7 +691,7 @@ class MemTransformerLM(nn.Module):
                                 r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2: # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)

@@ -160,7 +160,7 @@ np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print('WARNING: You have a CUDA device, so you should probably run with --cuda')
    else:
        torch.cuda.manual_seed_all(args.seed)

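The hunk above is the usual seed-everything preamble in the training script. A compact, hypothetical helper (not in this commit) that does the same job:

import numpy as np
import torch

def seed_everything(seed, use_cuda):
    # Hypothetical convenience wrapper: seed NumPy, CPU torch, and every GPU
    # so repeated runs are comparable.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda and torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_everything(1111, use_cuda=torch.cuda.is_available())
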
@@ -50,7 +50,7 @@ class AdaptiveLogSoftmax(nn.Module):
        head_logprob = F.log_softmax(head_logit, dim=1)

        nll = torch.zeros_like(target,
                               dtype=hidden.dtype, device=hidden.device)

        offset = 0
        cutoff_values = [0] + self.cutoffs

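To make the surrounding loop easier to follow: the nll buffer is filled bucket by bucket, where each vocabulary bucket [cutoff_values[i], cutoff_values[i+1]) computes log-probabilities for its own tokens and writes them back into the right rows. A toy sketch with made-up numbers and a stand-in for the real log-probabilities:

import torch

target = torch.tensor([1, 7, 3, 9, 0])
cutoff_values = [0, 5, 10]                    # head bucket: ids 0-4, tail bucket: ids 5-9
nll = torch.zeros_like(target, dtype=torch.float)

for i in range(len(cutoff_values) - 1):
    l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
    indices = ((target >= l_idx) & (target < r_idx)).nonzero().squeeze(1)
    if indices.numel() == 0:
        continue
    # stand-in for the bucket's real log-probabilities (head, or cluster + tail)
    logprob_i = torch.full((indices.numel(),), -float(i + 1))
    nll.index_copy_(0, indices, -logprob_i)

print(nll)   # tensor([1., 2., 1., 2., 1.]) -- head tokens got 1., tail tokens 2.
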
@@ -38,7 +38,7 @@ class LogUniformSampler(object):

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            true_log_probs = self.log_q[labels].to(device)
            samp_log_probs = self.log_q[neg_samples].to(device)

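For context on what self.dist is being sampled from: a log-uniform (Zipf-like) proposal over a frequency-sorted vocabulary, so low ids (frequent words) are drawn as negatives far more often. A small sketch assuming the usual log-uniform construction, with a toy vocabulary size; the real sampler keeps dist and log_q precomputed:

import torch

range_max = 10                                 # toy vocabulary size
log_indices = torch.arange(1., range_max + 2.).log_()
dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
# dist[k] = (log(k + 2) - log(k + 1)) / log(range_max + 1), sums to 1

n_tries = 6
neg_samples = torch.multinomial(dist, n_tries, replacement=True).unique()
print(dist.sum())      # tensor(1.) up to rounding
print(neg_samples)     # e.g. tensor([0, 1, 4]) -- duplicates collapsed by unique()
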
@@ -112,7 +112,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
        head_logprob = F.log_softmax(head_logit, dim=1)

        nll = torch.zeros_like(target,
                               dtype=hidden.dtype, device=hidden.device)

        offset = 0
        cutoff_values = [0] + self.cutoffs