chore(transformer-xl): Initial commit
parent ef4684ef39
commit 10512876f2
46 changed files with 10547 additions and 0 deletions
transformer-xl/pytorch/utils/adaptive_softmax.py (new file, 90 lines)
@@ -0,0 +1,90 @@

from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveLogSoftmax(nn.Module):
    def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
        super(AdaptiveLogSoftmax, self).__init__()

        cutoffs = list(cutoffs)

        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) >= (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any([int(c) != c for c in cutoffs]):
            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
        self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.keep_order = keep_order

    def forward(self, hidden, target, weight, bias, keep_order=False):
        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        head_weight = torch.cat(
            [weight[:self.shortlist_size], self.cluster_weight], dim=0)
        head_bias = torch.cat(
            [bias[:self.shortlist_size], self.cluster_bias], dim=0)

        head_logit = F.linear(hidden, head_weight, bias=head_bias)
        head_logprob = F.log_softmax(head_logit, dim=1)

        nll = torch.zeros_like(target,
                               dtype=hidden.dtype, device=hidden.device)

        offset = 0
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]

            mask_i = (target >= l_idx) & (target < h_idx)
            indices_i = mask_i.nonzero().squeeze()

            if indices_i.numel() == 0:
                continue

            target_i = target.index_select(0, indices_i) - l_idx
            head_logprob_i = head_logprob.index_select(0, indices_i)

            if i == 0:
                logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
            else:
                weight_i = weight[l_idx:h_idx]
                bias_i = bias[l_idx:h_idx]

                hidden_i = hidden.index_select(0, indices_i)

                tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
                tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                logprob_i = head_logprob_i[:, -i] \
                            + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

            if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                nll.index_copy_(0, indices_i, -logprob_i)
            else:
                nll[offset:offset + logprob_i.size(0)].copy_(-logprob_i)

            offset += logprob_i.size(0)

        return nll
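A minimal usage sketch for AdaptiveLogSoftmax, assuming the output weight is a tied embedding matrix of shape [n_classes, in_features] and a zero bias; the import path and the names vocab_size, d_model, emb are illustrative, not part of the commit:

import torch
import torch.nn as nn
from adaptive_softmax import AdaptiveLogSoftmax   # assuming this utils directory is on sys.path

vocab_size, d_model = 1000, 32                    # illustrative sizes
emb = nn.Embedding(vocab_size, d_model)           # its weight doubles as the softmax weight (weight tying)
bias = torch.zeros(vocab_size)
crit = AdaptiveLogSoftmax(d_model, vocab_size, cutoffs=[100, 500])

hidden = torch.randn(8, d_model)                  # [batch, in_features]
target = torch.tensor([3, 7, 42, 150, 250, 600, 777, 900])

# forward takes the output weight/bias explicitly and returns the
# per-token negative log-likelihood, one value per target.
nll = crit(hidden, target, emb.weight, bias)
print(nll.shape)                                  # torch.Size([8])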
transformer-xl/pytorch/utils/data_parallel.py (new file, 91 lines)
@@ -0,0 +1,91 @@

from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply


def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs


class BalancedDataParallel(DataParallel):
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
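BalancedDataParallel's scatter override gives GPU 0 a fixed, usually smaller, share (gpu0_bsz) and splits the remainder evenly across the other devices, handing any leftover samples to the non-zero GPUs one by one. A standalone sketch of just that chunk-size arithmetic, runnable without GPUs; the helper name and the numbers are illustrative:

def balanced_chunk_sizes(bsz, num_dev, gpu0_bsz):
    # Mirrors BalancedDataParallel.scatter: GPU 0 gets gpu0_bsz samples,
    # the rest of the batch is split evenly over the remaining devices,
    # and the remainder is distributed over the non-zero GPUs.
    bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
    chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
    for i in range(bsz - sum(chunk_sizes)):
        chunk_sizes[i + 1] += 1
    return chunk_sizes

print(balanced_chunk_sizes(bsz=62, num_dev=4, gpu0_bsz=6))   # [6, 19, 19, 18]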
transformer-xl/pytorch/utils/exp_utils.py (new file, 40 lines)
@@ -0,0 +1,40 @@

import functools
import os, shutil

import numpy as np

import torch


def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')

def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)

def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    if debug:
        print('Debug Mode : no experiment dir created')
        return functools.partial(logging, log_path=None, log_=False)

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    print('Experiment dir : {}'.format(dir_path))
    if scripts_to_save is not None:
        script_path = os.path.join(dir_path, 'scripts')
        if not os.path.exists(script_path):
            os.makedirs(script_path)
        for script in scripts_to_save:
            dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)

    return get_logger(log_path=os.path.join(dir_path, 'log.txt'))

def save_checkpoint(model, optimizer, path, epoch):
    torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
    torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
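A minimal usage sketch for the experiment helpers above, writing into a temporary directory; the import path, directory name, and log message are illustrative:

import os
import tempfile
from exp_utils import create_exp_dir   # assuming this utils directory is on sys.path

# create_exp_dir makes the directory (and optionally copies scripts into
# <dir>/scripts), then returns a logger bound to <dir>/log.txt.
work_dir = os.path.join(tempfile.mkdtemp(), 'exp1')
logger = create_exp_dir(work_dir)

# The returned callable both prints and appends to the log file.
logger('hello from the experiment logger')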
transformer-xl/pytorch/utils/log_uniform_sampler.py (new file, 147 lines)
@@ -0,0 +1,147 @@

import torch
from torch import nn
import numpy as np


class LogUniformSampler(object):
    def __init__(self, range_max, n_sample):
        """
        Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
            `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`

        The expected count can be approximated by 1 - (1 - p)^n,
        and we use the numerically stable version -expm1(num_tries * log1p(-p)).

        Our implementation fixes num_tries at 2 * n_sample, and the actual number of samples will vary from run to run.
        """
        with torch.no_grad():
            self.range_max = range_max
            log_indices = torch.arange(1., range_max+2., 1.).log_()
            self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
            # print('P', self.dist.numpy().tolist()[-30:])

            self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()

        self.n_sample = n_sample

    def sample(self, labels):
        """
            labels: [b1, b2]
        Return
            true_log_probs: [b1, b2]
            samp_log_probs: [n_sample]
            neg_samples: [n_sample]
        """

        # neg_samples = torch.empty(0).long()
        n_sample = self.n_sample
        n_tries = 2 * n_sample

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            true_log_probs = self.log_q[labels].to(device)
            samp_log_probs = self.log_q[neg_samples].to(device)
            return true_log_probs, samp_log_probs, neg_samples


def sample_logits(embedding, bias, labels, inputs, sampler):
    """
        embedding: an nn.Embedding layer
        bias: [n_vocab]
        labels: [b1, b2]
        inputs: [b1, b2, n_emb]
        sampler: you may use a LogUniformSampler
    Return
        logits: [b1, b2, 1 + n_sample]
    """
    true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
    n_sample = neg_samples.size(0)
    b1, b2 = labels.size(0), labels.size(1)
    all_ids = torch.cat([labels.view(-1), neg_samples])
    all_w = embedding(all_ids)
    true_w = all_w[: -n_sample].view(b1, b2, -1)
    sample_w = all_w[- n_sample:].view(n_sample, -1)

    all_b = bias[all_ids]
    true_b = all_b[: -n_sample].view(b1, b2)
    sample_b = all_b[- n_sample:]

    hit = (labels[:, :, None] == neg_samples).detach()

    true_logits = torch.einsum('ijk,ijk->ij',
        [true_w, inputs]) + true_b - true_log_probs
    sample_logits = torch.einsum('lk,ijk->ijl',
        [sample_w, inputs]) + sample_b - samp_log_probs
    sample_logits.masked_fill_(hit, -1e30)
    logits = torch.cat([true_logits[:, :, None], sample_logits], -1)

    return logits


# class LogUniformSampler(object):
#     def __init__(self, range_max, unique=False):
#         """
#         Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
#             `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
#         """
#         self.range_max = range_max
#         log_indices = torch.arange(1., range_max+2., 1.).log_()
#         self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
#
#         self.unique = unique
#
#         if self.unique:
#             self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
#
#     def sample(self, n_sample, labels):
#         pos_sample, new_labels = labels.unique(return_inverse=True)
#         n_pos_sample = pos_sample.size(0)
#         n_neg_sample = n_sample - n_pos_sample
#
#         if self.unique:
#             self.exclude_mask.index_fill_(0, pos_sample, 1)
#             sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
#             self.exclude_mask.index_fill_(0, pos_sample, 0)
#         else:
#             sample_dist = self.dist
#
#         neg_sample = torch.multinomial(sample_dist, n_neg_sample)
#
#         sample = torch.cat([pos_sample, neg_sample])
#         sample_prob = self.dist[sample]
#
#         return new_labels, sample, sample_prob


if __name__ == '__main__':
    S, B = 3, 4
    n_vocab = 10000
    n_sample = 5
    H = 32

    labels = torch.LongTensor(S, B).random_(0, n_vocab)

    # sampler = LogUniformSampler(n_vocab, unique=False)
    # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)

    # The current constructor takes (range_max, n_sample); the `unique=True`
    # keyword belongs to the commented-out legacy class above.
    sampler = LogUniformSampler(n_vocab, n_sample)
    # true_probs, samp_probs, neg_samples = sampler.sample(labels)

    # print('true_probs', true_probs.numpy().tolist())
    # print('samp_probs', samp_probs.numpy().tolist())
    # print('neg_samples', neg_samples.numpy().tolist())

    # print('sum', torch.sum(sampler.dist).item())

    # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()

    embedding = nn.Embedding(n_vocab, H)
    bias = torch.zeros(n_vocab)
    inputs = torch.Tensor(S, B, H).normal_()

    # sample_logits takes the sampler itself (no separate n_sample argument)
    # and returns only the logits, shaped [S, B, 1 + n_unique_neg_samples].
    logits = sample_logits(embedding, bias, labels, inputs, sampler)
    print('logits', logits.detach().numpy().tolist())
    print('logits shape', logits.size())
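The constructor above derives log_q from the numerically stable expected-count formula -expm1(n_tries * log1p(-p)) mentioned in its docstring. A quick sanity check of that identity against the naive 1 - (1 - p)**n_tries, using an illustrative tiny probability:

import math

p, n_tries = 1e-6, 40
naive  = 1.0 - (1.0 - p) ** n_tries           # suffers cancellation for tiny p
stable = -math.expm1(n_tries * math.log1p(-p))  # the form used for log_q
print(naive, stable)                          # both approximately 4.0e-05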
transformer-xl/pytorch/utils/proj_adaptive_softmax.py (new file, 151 lines)
@@ -0,0 +1,151 @@

from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
CUDA_MINOR = int(torch.version.cuda.split('.')[1])

class ProjectedAdaptiveLogSoftmax(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False):
        super(ProjectedAdaptiveLogSoftmax, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()

        if div_val == 1:
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(
                        nn.Parameter(torch.Tensor(d_proj, d_embed))
                    )
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)

                self.out_projs.append(
                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
                )

                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))

        self.keep_order = keep_order

    def _compute_logit(self, hidden, weight, bias, proj):
        if proj is None:
            logit = F.linear(hidden, weight, bias=bias)
        else:
            # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
            proj_hid = F.linear(hidden, proj.t().contiguous())
            logit = F.linear(proj_hid, weight, bias=bias)
            # else:
            #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
            #     if bias is not None:
            #         logit = logit + bias

        return logit

    def forward(self, hidden, target, keep_order=False):
        '''
            hidden :: [len*bsz x d_proj]
            target :: [len*bsz]
        '''

        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                        self.out_layers[0].bias, self.out_projs[0])
            nll = -F.log_softmax(logit, dim=-1) \
                    .gather(1, target.unsqueeze(1)).squeeze(1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    weight_i = torch.cat(
                        [weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat(
                        [bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = F.log_softmax(head_logit, dim=1)

            nll = torch.zeros_like(target,
                                   dtype=hidden.dtype, device=hidden.device)

            offset = 0
            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

                mask_i = (target >= l_idx) & (target < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                target_i = target.index_select(0, indices_i) - l_idx
                head_logprob_i = head_logprob.index_select(0, indices_i)

                if i == 0:
                    logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    hidden_i = hidden.index_select(0, indices_i)

                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                    logprob_i = head_logprob_i[:, -i] \
                                + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

                if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                    nll.index_copy_(0, indices_i, -logprob_i)
                else:
                    nll[offset:offset + logprob_i.size(0)].copy_(-logprob_i)

                offset += logprob_i.size(0)

        return nll
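A minimal usage sketch for ProjectedAdaptiveLogSoftmax, assuming a CUDA build of PyTorch (the module reads torch.version.cuda at import time), d_proj == d_embed with div_val == 1 so no projection parameters are created, and a PyTorch version where nn.ParameterList accepts None entries, as this module relies on; the import path and sizes are illustrative:

import torch
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax  # assuming this utils directory is on sys.path

n_token, d_embed = 1000, 32                        # illustrative sizes; d_proj set equal to d_embed
crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_embed, cutoffs=[100, 500])

hidden = torch.randn(8, d_embed)                   # [len*bsz, d_proj]
target = torch.tensor([3, 7, 42, 150, 250, 600, 777, 900])

# Unlike AdaptiveLogSoftmax, the weights live inside the module itself;
# forward returns the per-token negative log-likelihood.
nll = crit(hidden, target)
print(nll.shape)                                   # torch.Size([8])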
transformer-xl/pytorch/utils/vocabulary.py (new file, 163 lines)
@@ -0,0 +1,163 @@

import os
from collections import Counter, OrderedDict

import torch


class Vocab(object):
    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
                 delimiter=None, vocab_file=None):
        self.counter = Counter()
        self.special = special
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file

    def tokenize(self, line, add_eos=False, add_double_eos=False):
        line = line.strip()
        # convert to lower case
        if self.lower_case:
            line = line.lower()

        # empty delimiter '' will evaluate False
        if self.delimiter == '':
            symbols = line
        else:
            symbols = line.split(self.delimiter)

        if add_double_eos: # lm1b
            return ['<S>'] + symbols + ['<S>']
        elif add_eos:
            return symbols + ['<eos>']
        else:
            return symbols

    def count_file(self, path, verbose=False, add_eos=False):
        if verbose: print('counting file {} ...'.format(path))
        assert os.path.exists(path)

        sents = []
        with open(path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('    line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents

    def count_sents(self, sents, verbose=False):
        """
            sents : a list of sentences, each a list of tokenized symbols
        """
        if verbose: print('counting {} sents ...'.format(len(sents)))
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print('    line {}'.format(idx))
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        self.idx2sym = []
        self.sym2idx = OrderedDict()

        with open(vocab_file, 'r', encoding='utf-8') as f:
            for line in f:
                symb = line.strip().split()[0]
                self.add_symbol(symb)
        self.unk_idx = self.sym2idx['<UNK>']

    def build_vocab(self):
        if self.vocab_file:
            print('building vocab from {}'.format(self.vocab_file))
            self._build_from_file(self.vocab_file)
            print('final vocab size {}'.format(len(self)))
        else:
            print('building vocab with min_freq={}, max_size={}'.format(
                self.min_freq, self.max_size))
            self.idx2sym = []
            self.sym2idx = OrderedDict()

            for sym in self.special:
                self.add_special(sym)

            for sym, cnt in self.counter.most_common(self.max_size):
                if cnt < self.min_freq: break
                self.add_symbol(sym)

            print('final vocab size {} from {} unique tokens'.format(
                len(self), len(self.counter)))

    def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
                    add_double_eos=False):
        if verbose: print('encoding file {} ...'.format(path))
        assert os.path.exists(path)
        encoded = []
        with open(path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('    line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos,
                                        add_double_eos=add_double_eos)
                encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def encode_sents(self, sents, ordered=False, verbose=False):
        if verbose: print('encoding {} sents ...'.format(len(sents)))
        encoded = []
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print('    line {}'.format(idx))
            encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def add_special(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

    def add_symbol(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def get_sym(self, idx):
        assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
        return self.idx2sym[idx]

    def get_idx(self, sym):
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        else:
            # print('encounter unk {}'.format(sym))
            assert '<eos>' not in sym
            assert hasattr(self, 'unk_idx')
            return self.sym2idx.get(sym, self.unk_idx)

    def get_symbols(self, indices):
        return [self.get_sym(idx) for idx in indices]

    def get_indices(self, symbols):
        return [self.get_idx(sym) for sym in symbols]

    def convert_to_tensor(self, symbols):
        return torch.LongTensor(self.get_indices(symbols))

    def convert_to_sent(self, indices, exclude=None):
        if exclude is None:
            return ' '.join([self.get_sym(idx) for idx in indices])
        else:
            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

    def __len__(self):
        return len(self.idx2sym)
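A small end-to-end sketch for Vocab on in-memory sentences, assuming an '<unk>' special is registered so get_idx has a fallback index; the import path and toy corpus are illustrative:

from vocabulary import Vocab   # assuming this utils directory is on sys.path

sents = [['the', 'cat', 'sat'], ['the', 'dog', 'sat'], ['the', 'cat', 'ran']]

vocab = Vocab(special=['<unk>', '<eos>'], min_freq=0)
vocab.count_sents(sents)        # accumulate token counts
vocab.build_vocab()             # specials first, then tokens by frequency

print(len(vocab))                                        # vocabulary size incl. specials
print(vocab.convert_to_tensor(['the', 'cat', 'flew']))   # 'flew' falls back to the <unk> index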