feat: uhm, i changed some things

This commit is contained in:
RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@ -50,7 +50,7 @@ class AdaptiveLogSoftmax(nn.Module):
head_logprob = F.log_softmax(head_logit, dim=1)
nll = torch.zeros_like(target,
dtype=hidden.dtype, device=hidden.device)
dtype=hidden.dtype, device=hidden.DEVICE)
offset = 0
cutoff_values = [0] + self.cutoffs

View file

@ -38,7 +38,7 @@ class LogUniformSampler(object):
with torch.no_grad():
neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
device = labels.device
device = labels.DEVICE
neg_samples = neg_samples.to(device)
true_log_probs = self.log_q[labels].to(device)
samp_log_probs = self.log_q[neg_samples].to(device)

View file

@ -112,7 +112,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
head_logprob = F.log_softmax(head_logit, dim=1)
nll = torch.zeros_like(target,
dtype=hidden.dtype, device=hidden.device)
dtype=hidden.dtype, device=hidden.DEVICE)
offset = 0
cutoff_values = [0] + self.cutoffs