from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: torch.version.cuda is None on CPU-only builds of PyTorch; these
# constants are only consulted by the (commented-out) einsum path in
# _compute_logit, so guard or drop them if the module must import on CPU-only
# installs.
CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
CUDA_MINOR = int(torch.version.cuda.split('.')[1])

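
# Layout sketch, added as a reading aid (the numbers are illustrative
# assumptions, not taken from this file): with cutoffs=[200, 500] and
# n_token=1000, the head softmax covers tokens 0..199 plus one logit per tail
# cluster (head_size = 202), and two tail softmaxes cover tokens 200..499 and
# 500..999. For a tail token, log p(token) is the cluster's log-probability
# under the head softmax plus the token's log-probability under its tail
# softmax, which is exactly what forward() accumulates below.
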
class ProjectedAdaptiveLogSoftmax(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False):
        super(ProjectedAdaptiveLogSoftmax, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()

        if div_val == 1:
            # Single embedding size for every cluster: one big output layer,
            # sliced per cluster in forward().
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(
                        nn.Parameter(torch.Tensor(d_proj, d_embed))
                    )
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            # Embedding size shrinks by a factor of div_val per cluster: one
            # output layer and one projection per cluster.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)

                self.out_projs.append(
                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
                )

                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))

        self.keep_order = keep_order

    def _compute_logit(self, hidden, weight, bias, proj):
        if proj is None:
            logit = F.linear(hidden, weight, bias=bias)
        else:
            # Project the hidden state down to this cluster's embedding size,
            # then apply the output layer.
            # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
            proj_hid = F.linear(hidden, proj.t().contiguous())
            logit = F.linear(proj_hid, weight, bias=bias)
            # else:
            #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
            #     if bias is not None:
            #         logit = logit + bias

        return logit

    def forward(self, hidden, target, keep_order=False):
        '''
            hidden :: [len*bsz x d_proj]
            target :: [len*bsz]
        '''

        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        if self.n_clusters == 0:
            # No clusters: plain (projected) softmax over the full vocabulary.
            logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                        self.out_layers[0].bias, self.out_projs[0])
            nll = -F.log_softmax(logit, dim=-1) \
                    .gather(1, target.unsqueeze(1)).squeeze(1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    # The head softmax covers the shortlist plus one logit per
                    # tail cluster, appended at the end.
                    weight_i = torch.cat(
                        [weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat(
                        [bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = F.log_softmax(head_logit, dim=1)

            nll = torch.zeros_like(target,
                                   dtype=hidden.dtype, device=hidden.device)

            offset = 0
            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

                # Select the targets that fall into this cluster's vocabulary range.
                mask_i = (target >= l_idx) & (target < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                target_i = target.index_select(0, indices_i) - l_idx
                head_logprob_i = head_logprob.index_select(0, indices_i)

                if i == 0:
                    # Shortlist tokens: log-probability comes directly from
                    # the head softmax.
                    logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    hidden_i = hidden.index_select(0, indices_i)

                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                    # Tail tokens: cluster log-prob (the cluster logits sit at
                    # the end of the head, hence the negative index) plus the
                    # within-cluster log-prob.
                    logprob_i = head_logprob_i[:, -i] \
                              + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)

                if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                    nll.index_copy_(0, indices_i, -logprob_i)
                else:
                    nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)

                offset += logprob_i.size(0)

        return nll
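
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the vocabulary size,
# cutoffs, batch shape, and the normal_(0, 0.02) init below are illustrative
# assumptions, chosen only to show how the class is constructed and called.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    n_token, d_embed, d_proj = 1000, 64, 64
    cutoffs = [200, 500]  # 200-token shortlist plus two tail clusters

    crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj, cutoffs,
                                       div_val=2)

    # The projections are allocated with torch.Tensor (uninitialized storage),
    # so give every weight matrix a small random init before use.
    for p in crit.parameters():
        if p.dim() > 1:
            nn.init.normal_(p, 0.0, 0.02)

    # Draw several targets from each vocabulary range so every cluster gets
    # more than one hit (the squeeze() in forward assumes a 1-D index tensor).
    target = torch.cat([torch.randint(0, 200, (12,)),
                        torch.randint(200, 500, (10,)),
                        torch.randint(500, 1000, (10,))])
    hidden = torch.randn(target.size(0), d_proj)  # [len*bsz x d_proj]

    nll = crit(hidden, target)  # per-token negative log-likelihood
    print(nll.shape, nll.mean().item())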