feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -176,9 +176,9 @@ class RelMultiHeadAttn(nn.Module):
def _shift(self, x, qlen, klen, mask, left=False):
if qlen > 1:
zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
device=x.DEVICE, dtype=x.dtype)
else:
zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
zero_pad = torch.zeros(0, device=x.DEVICE, dtype=x.dtype)
if left:
mask = mask.flip(1)
@ -193,7 +193,7 @@ class RelMultiHeadAttn(nn.Module):
def _rel_shift(self, x, zero_triu=False):
zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
device=x.device, dtype=x.dtype)
device=x.DEVICE, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=1)
x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
@ -661,7 +661,7 @@ class MemTransformerLM(nn.Module):
hids = []
if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.DEVICE,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
@ -691,7 +691,7 @@ class MemTransformerLM(nn.Module):
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out)
elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.DEVICE,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)