chore(transformer-xl): Initial commit
This commit is contained in:
parent
ef4684ef39
commit
10512876f2
46 changed files with 10547 additions and 0 deletions
546
transformer-xl/tf/model.py
Normal file
546
transformer-xl/tf/model.py
Normal file
|
|
@ -0,0 +1,546 @@
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def positional_embedding(pos_seq, inv_freq, bsz=None):
    """Build sinusoidal position embeddings with an inserted middle axis.

    The outer product of `pos_seq` and `inv_freq` is passed through sine and
    cosine, and both results are concatenated along the last axis.

    Args:
        pos_seq: 1-D tensor of positions.
        inv_freq: 1-D tensor of inverse frequencies.
        bsz: optional batch size. When given, the embedding is tiled `bsz`
            times along the inserted middle axis; otherwise that axis is
            left with size 1.

    Returns:
        A rank-3 tensor `[len(pos_seq), bsz or 1, 2 * len(inv_freq)]`.
    """
    angles = tf.einsum('i,j->ij', pos_seq, inv_freq)
    sin_cos = tf.concat([tf.sin(angles), tf.cos(angles)], -1)
    expanded = sin_cos[:, None, :]
    if bsz is None:
        return expanded
    return tf.tile(expanded, [1, bsz, 1])
|
||||
|
||||
|
||||
def positionwise_FF(inp, d_model, d_inner, dropout, kernel_initializer,
                    scope='ff', is_training=True):
    """Position-wise feed-forward sub-layer with residual and layer norm.

    Applies dense(d_inner, ReLU) -> dropout -> dense(d_model) -> dropout,
    adds the result to `inp` and layer-normalises over the last axis.

    Args:
        inp: input tensor; its last axis must be compatible with `d_model`
            for the residual addition.
        d_model: output width of the second dense layer.
        d_inner: width of the hidden (first) dense layer.
        dropout: dropout rate used after each dense layer.
        kernel_initializer: initializer for both dense kernels.
        scope: variable scope name.
        is_training: whether dropout is active.

    Returns:
        The layer-normalised tensor `layer_norm(ff(inp) + inp)`.
    """
    with tf.variable_scope(scope):
        hidden = tf.layers.dense(
            inp, d_inner,
            activation=tf.nn.relu,
            kernel_initializer=kernel_initializer,
            name='layer_1')
        hidden = tf.layers.dropout(
            hidden, dropout, training=is_training, name='drop_1')

        projected = tf.layers.dense(
            hidden, d_model,
            kernel_initializer=kernel_initializer,
            name='layer_2')
        projected = tf.layers.dropout(
            projected, dropout, training=is_training, name='drop_2')

        # Residual connection followed by normalisation of the feature axis.
        output = tf.contrib.layers.layer_norm(
            projected + inp, begin_norm_axis=-1)
    return output
|
||||
|
||||
|
||||
def rel_shift(x):
    """Shift a rank-4 tensor via a pad / reshape / slice / reshape sequence.

    One zero slot is prepended on axis 1, the buffer is reinterpreted with
    the sizes of the first two axes exchanged, the first slice along axis 0
    is dropped, and the remainder is reshaped back to the input shape.

    Args:
        x: rank-4 tensor.

    Returns:
        A tensor with the same shape as `x`.
    """
    original_shape = tf.shape(x)
    dim0, dim1 = original_shape[0], original_shape[1]
    dim2, dim3 = original_shape[2], original_shape[3]

    padded = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
    reinterpreted = tf.reshape(padded, [dim1 + 1, dim0, dim2, dim3])
    # Drop the leading slice on axis 0; -1 keeps every remaining element.
    trimmed = tf.slice(reinterpreted, [1, 0, 0, 0], [-1, -1, -1, -1])
    return tf.reshape(trimmed, original_shape)
|
||||
|
||||
|
||||
def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model,
                       n_head, d_head, dropout, dropatt, is_training,
                       kernel_initializer, scope='rel_attn'):
    """Multi-head attention with a position-based score term.

    Args:
        w: input tensor indexed as [qlen, bsz, ...].
        r: position embedding tensor; its first axis is the relative length.
        r_w_bias: bias added to the queries for the content score (AC).
        r_r_bias: bias added to the queries for the position score (BD).
        attn_mask: 2-D mask; entries equal to 1 are pushed to -1e30 before
            the softmax.
        mems: optional memory tensor. When it is not None and has rank > 1
            it is concatenated in front of `w` to form keys and values.
        d_model: output width of the final projection.
        n_head: number of heads.
        d_head: width of each head.
        dropout: dropout rate on the projected attention output.
        dropatt: dropout rate on the attention probabilities.
        is_training: whether dropout is active.
        kernel_initializer: initializer for all dense kernels.
        scope: variable scope name.

    Returns:
        `layer_norm(attention_output + w)`.
    """
    scale = 1 / (d_head ** 0.5)
    with tf.variable_scope(scope):
        w_dims = tf.shape(w)
        qlen, bsz = w_dims[0], w_dims[1]
        rlen = tf.shape(r)[0]

        # Keys/values span memory + current input when a usable memory exists.
        if mems is not None and mems.shape.ndims > 1:
            cat = tf.concat([mems, w], 0)
        else:
            cat = w

        w_heads = tf.layers.dense(
            cat, 3 * n_head * d_head, use_bias=False,
            kernel_initializer=kernel_initializer, name='qkv')
        r_head_k = tf.layers.dense(
            r, n_head * d_head, use_bias=False,
            kernel_initializer=kernel_initializer, name='r')

        w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1)
        # Only the last qlen rows (the current input) act as queries.
        w_head_q = w_head_q[-qlen:]
        klen = tf.shape(w_head_k)[0]

        key_value_shape = [klen, bsz, n_head, d_head]
        w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head])
        w_head_k = tf.reshape(w_head_k, key_value_shape)
        w_head_v = tf.reshape(w_head_v, key_value_shape)
        r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head])

        rw_head_q = w_head_q + r_w_bias
        rr_head_q = w_head_q + r_r_bias

        # AC: query x content keys; BD: query x position keys, then shifted.
        AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
        BD = rel_shift(tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k))

        attn_score = (AC + BD) * scale
        attn_mask_t = attn_mask[:, :, None, None]
        keep = 1 - attn_mask_t
        attn_score = attn_score * keep - 1e30 * attn_mask_t

        # Normalise over the key axis (axis 1 of [i, j, b, n]).
        attn_prob = tf.nn.softmax(attn_score, 1)
        attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

        attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
        vec_dims = tf.shape(attn_vec)
        attn_vec = tf.reshape(
            attn_vec, [vec_dims[0], vec_dims[1], n_head * d_head])

        attn_out = tf.layers.dense(
            attn_vec, d_model, use_bias=False,
            kernel_initializer=kernel_initializer, name='o')
        attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)

        output = tf.contrib.layers.layer_norm(
            attn_out + w, begin_norm_axis=-1)
    return output
|
||||
|
||||
|
||||
def embedding_lookup(lookup_table, x, use_tpu=True):
    """Look up rows of `lookup_table` for the ids in `x`.

    Args:
        lookup_table: 2-D tensor; axis 0 indexes tokens.
        x: integer id tensor. On the one-hot path a rank-1 `x` uses the
            'nd,in->id' contraction and any other rank uses 'nd,ibn->ibd'.
        use_tpu: if True, perform the lookup as a one-hot contraction;
            otherwise defer to `tf.nn.embedding_lookup`.

    Returns:
        The gathered embeddings.
    """
    if not use_tpu:
        return tf.nn.embedding_lookup(lookup_table, x)

    table_rows = tf.shape(lookup_table)[0]
    one_hot_ids = tf.one_hot(x, table_rows)
    if one_hot_ids.shape.ndims == 2:
        equation = 'nd,in->id'
    else:
        equation = 'nd,ibn->ibd'
    return tf.einsum(equation, lookup_table, one_hot_ids)
|
||||
|
||||
|
||||
def mask_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer,
                                   proj_initializer, div_val=1,
                                   proj_same_dim=True,
                                   scope='adaptive_embed', **kwargs):
    """Adaptive embedding lookup implemented with boolean masks.

    With `div_val == 1` a single `[n_token, d_embed]` table is used (plus a
    projection to `d_proj` when the widths differ). Otherwise the vocabulary
    is split at `cutoffs`; bin `i` owns a table of width
    `d_embed // div_val ** i`, and the embeddings of the ids falling in each
    bin are scattered back into the output.

    Args:
        x: 2-D integer id tensor.
        n_token: vocabulary size.
        d_embed: embedding width of the first bin.
        d_proj: output width.
        cutoffs: list of python ints with the bin boundaries.
        initializer: initializer for the lookup tables.
        proj_initializer: initializer for the projection matrices.
        div_val: per-bin width divisor.
        proj_same_dim: if False, a bin whose width already equals `d_proj`
            gets no projection.
        scope: variable scope name.
        **kwargs: accepted and ignored.

    Returns:
        `(y, ret_params)`: the embeddings scaled by `sqrt(d_proj)`, and
        `[lookup_table, proj_W]` (or `[tables, projs]` lists when
        `div_val != 1`). Projection entries are None where none was created.
    """
    emb_scale = d_proj ** 0.5
    with tf.variable_scope(scope):
        if div_val != 1:
            tables, projs = [], []
            cutoff_ends = [0] + cutoffs + [n_token]
            x_dims = tf.shape(x)
            y = tf.zeros([x_dims[0], x_dims[1], d_proj])
            bins = zip(cutoff_ends[:-1], cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bins):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    in_bin = (x >= l_idx) & (x < r_idx)
                    # Ids re-based so that they index this bin's own table.
                    cur_x = tf.boolean_mask(x, in_bin) - l_idx
                    cur_d_embed = d_embed // (div_val ** i)
                    lookup_table = tf.get_variable(
                        'lookup_table', [r_idx - l_idx, cur_d_embed],
                        initializer=initializer)
                    cur_y = embedding_lookup(lookup_table, cur_x, use_tpu=False)

                    needs_proj = proj_same_dim or d_proj != cur_d_embed
                    if needs_proj:
                        proj_W = tf.get_variable(
                            'proj_W', [cur_d_embed, d_proj],
                            initializer=proj_initializer)
                        cur_y = tf.einsum('id,de->ie', cur_y, proj_W)
                    else:
                        proj_W = None

                    # Write this bin's rows back at their original positions.
                    mask_idx = tf.to_int64(tf.where(in_bin))
                    y = y + tf.scatter_nd(
                        mask_idx, cur_y, tf.to_int64(tf.shape(y)))
                    tables.append(lookup_table)
                    projs.append(proj_W)
            ret_params = [tables, projs]
        else:
            lookup_table = tf.get_variable(
                'lookup_table', [n_token, d_embed], initializer=initializer)
            y = embedding_lookup(lookup_table, x, use_tpu=False)
            proj_W = None
            if d_proj != d_embed:
                proj_W = tf.get_variable(
                    'proj_W', [d_embed, d_proj], initializer=proj_initializer)
                y = tf.einsum('ibe,ed->ibd', y, proj_W)
            ret_params = [lookup_table, proj_W]

    y = y * emb_scale
    return y, ret_params
|
||||
|
||||
|
||||
def mul_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer,
                                  proj_initializer, div_val=1, perms=None,
                                  proj_same_dim=True,
                                  scope='adaptive_embed'):
    """Adaptive embedding lookup implemented with dense multiplications.

    Args:
        x: 2-D integer id tensor.
        n_token: vocabulary size.
        d_embed: embedding width of the first bin.
        d_proj: output width.
        cutoffs: list of python ints with the bin boundaries.
        initializer: initializer for the lookup tables.
        proj_initializer: initializer for the projection matrices.
        div_val: per-bin width divisor; bin `i` has width
            `d_embed // div_val ** i`.
        perms: If None, first compute W = W1 x W2 (projection for each bin),
            and then compute X x W (embedding lookup). If not None,
            use bin-based embedding lookup with max_bin_size defined by
            the shape of perms.
        proj_same_dim: if False, a bin whose width already equals `d_proj`
            gets no projection.
        scope: variable scope name.

    Returns:
        `(y, ret_params)`: the embeddings scaled by `sqrt(d_proj)`, and
        `[lookup_table, proj_W]` (or `[tables, projs]` lists when
        `div_val != 1`). Projection entries are None where none was created.
    """
    emb_scale = d_proj ** 0.5
    with tf.variable_scope(scope):
        if div_val == 1:
            lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
                                           initializer=initializer)
            y = embedding_lookup(lookup_table, x)
            if d_proj != d_embed:
                proj_W = tf.get_variable('proj_W', [d_embed, d_proj],
                                         initializer=proj_initializer)
                y = tf.einsum('ibe,ed->ibd', y, proj_W)
            else:
                proj_W = None
            ret_params = [lookup_table, proj_W]
        else:
            tables, projs = [], []
            cutoff_ends = [0] + cutoffs + [n_token]
            if perms is None:
                # Collect one projected table per bin; concatenated below.
                cat_lookup = []
            else:
                x_size = tf.shape(x)
                cat_lookup = tf.zeros([x_size[0], x_size[1], d_proj])
            for i in range(len(cutoff_ends) - 1):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
                    cur_d_embed = d_embed // (div_val ** i)
                    lookup_table = tf.get_variable('lookup_table',
                                                   [r_idx - l_idx, cur_d_embed],
                                                   initializer=initializer)
                    if cur_d_embed == d_proj and not proj_same_dim:
                        proj_W = None
                    else:
                        proj_W = tf.get_variable('proj_W', [cur_d_embed, d_proj],
                                                 initializer=proj_initializer)
                    if perms is None:
                        if proj_W is None:
                            # No projection exists for this bin (its width
                            # already equals d_proj), so the table is used
                            # as-is instead of being multiplied by None.
                            cat_lookup.append(lookup_table)
                        else:
                            cat_lookup.append(
                                tf.einsum('ie,ed->id', lookup_table, proj_W))
                    else:
                        # speed up the computation of the first bin
                        # also save some memory
                        if i == 0:
                            # Ids beyond this bin are clamped to its last row;
                            # perms[0] then masks their contribution.
                            cur_y = embedding_lookup(
                                lookup_table, tf.minimum(x, r_idx - 1))
                            if proj_W is not None:
                                cur_y = tf.einsum('ibe,ed->ibd', cur_y, proj_W)
                            cur_y *= perms[i][:, :, None]
                            cat_lookup += cur_y
                        else:
                            # Gather this bin's ids into k slots via perms[i].
                            cur_x = tf.einsum('ib,ibk->k',
                                              tf.to_float(x - l_idx), perms[i])
                            cur_x = tf.to_int32(cur_x)
                            cur_y = embedding_lookup(lookup_table, cur_x)
                            if proj_W is not None:
                                cur_y = tf.einsum('ke,ed->kd', cur_y, proj_W)
                            # Scatter the k slots back to their [i, b] positions.
                            cat_lookup += tf.einsum('kd,ibk->ibd', cur_y, perms[i])
                    tables.append(lookup_table)
                    projs.append(proj_W)
            if perms is None:
                cat_lookup = tf.concat(cat_lookup, 0)
                y = embedding_lookup(cat_lookup, x)
            else:
                y = cat_lookup
            ret_params = [tables, projs]

    y *= emb_scale
    return y, ret_params
|
||||
|
||||
|
||||
def mask_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
                             params, tie_projs,
                             initializer=None, proj_initializer=None,
                             div_val=1, scope='adaptive_softmax',
                             proj_same_dim=True,
                             return_mean=True, **kwargs):
    """Adaptive softmax negative log-likelihood using boolean masks.

    Args:
        hidden: rank-3 hidden states indexed as [i, b, d].
        target: 2-D integer target ids.
        n_token: vocabulary size.
        d_embed: embedding width of the first bin.
        d_proj: width of `hidden`'s last axis.
        cutoffs: list of python ints with the bin boundaries; an empty list
            selects a plain full softmax.
        params: pair `(weights, projections)` as returned by the embedding
            lookup functions.
        tie_projs: list of python bools; for bin `i`, whether to reuse the
            embedding projection instead of creating a new one.
        initializer: unused here.
        proj_initializer: initializer for newly created projections.
        div_val: per-bin width divisor.
        scope: variable scope name.
        proj_same_dim: if False, an untied bin whose width equals `d_proj`
            gets no projection.
        return_mean: if True, reduce the per-position loss to its mean.
        **kwargs: accepted and ignored.

    Returns:
        The mean NLL, or the per-position NLL when `return_mean` is False.
    """
    def _logit(x, W, b, proj):
        # Optionally map x into the weight space before the output product.
        projected = x if proj is None else tf.einsum('ibd,ed->ibe', x, proj)
        return tf.einsum('ibd,nd->ibn', projected, W) + b

    def _gather_logprob(logprob, labels):
        # Pick logprob[row, labels[row]] for every row.
        row_ids = tf.range(tf.shape(logprob)[0])
        return tf.gather_nd(logprob, tf.stack([row_ids, labels], 1))

    params_W, params_projs = params[0], params[1]

    with tf.variable_scope(scope):
        if len(cutoffs) == 0:
            softmax_b = tf.get_variable(
                'bias', [n_token], initializer=tf.zeros_initializer())
            output = _logit(hidden, params_W, softmax_b, params_projs)
            nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=target, logits=output)
        else:
            cutoff_ends = [0] + cutoffs + [n_token]
            nll = tf.zeros_like(target, dtype=tf.float32)
            bins = zip(cutoff_ends[:-1], cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bins):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    mask = (target >= l_idx) & (target < r_idx)
                    mask_idx = tf.where(mask)
                    cur_target = tf.boolean_mask(target, mask) - l_idx
                    cur_d_embed = d_embed // (div_val ** i)

                    cur_W = params_W[l_idx: r_idx] if div_val == 1 else params_W[i]
                    cur_b = tf.get_variable(
                        'b', [r_idx - l_idx],
                        initializer=tf.zeros_initializer())

                    if tie_projs[i]:
                        cur_proj = params_projs if div_val == 1 else params_projs[i]
                    elif (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed:
                        cur_proj = None
                    else:
                        cur_proj = tf.get_variable(
                            'proj', [cur_d_embed, d_proj],
                            initializer=proj_initializer)

                    if i == 0:
                        # The head scores the first bin's tokens plus one
                        # extra entry per tail cluster.
                        cluster_W = tf.get_variable(
                            'cluster_W', [len(cutoffs), d_embed],
                            initializer=tf.zeros_initializer())
                        cluster_b = tf.get_variable(
                            'cluster_b', [len(cutoffs)],
                            initializer=tf.zeros_initializer())
                        cur_W = tf.concat([cur_W, cluster_W], 0)
                        cur_b = tf.concat([cur_b, cluster_b], 0)

                        head_logit = _logit(hidden, cur_W, cur_b, cur_proj)
                        head_logprob = tf.nn.log_softmax(head_logit)
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_logprob = _gather_logprob(
                            cur_head_logprob, cur_target)
                    else:
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_hidden = tf.boolean_mask(hidden, mask)
                        # _logit expects rank 3: add, then remove, a unit axis.
                        tail_logit = tf.squeeze(
                            _logit(cur_hidden[None], cur_W, cur_b, cur_proj), 0)
                        tail_logprob = tf.nn.log_softmax(tail_logit)
                        # Cluster i sits after the first bin's tokens in the head.
                        cluster_col = cutoff_ends[1] + i - 1
                        cur_logprob = (
                            cur_head_logprob[:, cluster_col] +
                            _gather_logprob(tail_logprob, cur_target))

                    nll = nll + tf.scatter_nd(
                        mask_idx, -cur_logprob, tf.to_int64(tf.shape(nll)))
    if return_mean:
        nll = tf.reduce_mean(nll)
    return nll
|
||||
|
||||
|
||||
def mul_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
                            params, tie_projs,
                            initializer=None, proj_initializer=None,
                            div_val=1, perms=None, proj_same_dim=True,
                            scope='adaptive_softmax',
                            **kwargs):
    """Adaptive softmax mean NLL using dense multiplications with `perms`.

    Args:
        hidden: rank-3 hidden states indexed as [i, b, d].
        target: 2-D integer target ids.
        n_token: vocabulary size.
        d_embed: embedding width of the first bin.
        d_proj: width of `hidden`'s last axis.
        cutoffs: list of python ints with the bin boundaries; an empty list
            selects a plain full softmax.
        params: pair `(weights, projections)` as returned by the embedding
            lookup functions.
        tie_projs: list of python bools; for bin `i`, whether to reuse the
            embedding projection instead of creating a new one.
        initializer: unused here.
        proj_initializer: initializer for newly created projections.
        div_val: per-bin width divisor.
        perms: per-bin tensors; indexed for every bin when `cutoffs` is
            non-empty, so it must be provided in that case. `perms[0]` is
            multiplied element-wise with the [i, b] head loss; later
            entries are contracted as [i, b, k].
        proj_same_dim: if False, an untied bin whose width equals `d_proj`
            gets no projection.
        scope: variable scope name.
        **kwargs: `head_target` is read from here and used as the labels of
            the head softmax.

    Returns:
        Scalar mean negative log-likelihood.
    """
    def _logit(x, W, b, proj):
        # Handles both [i, b, d] and flattened [k, d] inputs.
        if x.shape.ndims == 3:
            proj_eq, out_eq = 'ibd,ed->ibe', 'ibd,nd->ibn'
        else:
            proj_eq, out_eq = 'id,ed->ie', 'id,nd->in'
        projected = x if proj is None else tf.einsum(proj_eq, x, proj)
        return tf.einsum(out_eq, projected, W) + b

    params_W, params_projs = params[0], params[1]

    with tf.variable_scope(scope):
        if len(cutoffs) == 0:
            softmax_b = tf.get_variable(
                'bias', [n_token], initializer=tf.zeros_initializer())
            output = _logit(hidden, params_W, softmax_b, params_projs)
            nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=target, logits=output)
            nll = tf.reduce_mean(nll)
        else:
            total_loss, total_cnt = 0, 0
            cutoff_ends = [0] + cutoffs + [n_token]
            bins = zip(cutoff_ends[:-1], cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bins):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    cur_d_embed = d_embed // (div_val ** i)

                    cur_W = params_W[l_idx: r_idx] if div_val == 1 else params_W[i]
                    cur_b = tf.get_variable(
                        'b', [r_idx - l_idx],
                        initializer=tf.zeros_initializer())

                    if tie_projs[i]:
                        cur_proj = params_projs if div_val == 1 else params_projs[i]
                    elif (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed:
                        cur_proj = None
                    else:
                        cur_proj = tf.get_variable(
                            'proj', [cur_d_embed, d_proj],
                            initializer=proj_initializer)

                    if i == 0:
                        # The head scores the first bin's tokens plus one
                        # extra entry per tail cluster.
                        cluster_W = tf.get_variable(
                            'cluster_W', [len(cutoffs), d_embed],
                            initializer=tf.zeros_initializer())
                        cluster_b = tf.get_variable(
                            'cluster_b', [len(cutoffs)],
                            initializer=tf.zeros_initializer())
                        cur_W = tf.concat([cur_W, cluster_W], 0)
                        cur_b = tf.concat([cur_b, cluster_b], 0)

                        head_logit = _logit(hidden, cur_W, cur_b, cur_proj)

                        head_target = kwargs.get("head_target")
                        head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=head_target,
                            logits=head_logit)

                        # Only positions selected by perms[0] count here.
                        total_loss += tf.reduce_sum(head_nll * perms[i])
                        total_cnt += tf.reduce_sum(perms[i])

                        # Alternative formulation kept from the original
                        # source (disabled):
                        # head_logprob = tf.nn.log_softmax(head_logit)
                        # final_logprob = head_logprob * perms[i][:, :, None]
                        # final_target = tf.one_hot(target, tf.shape(head_logprob)[2])
                        # total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target)
                        # total_cnt += tf.reduce_sum(perms[i])
                    else:
                        # Gather head loss, hidden states and re-based targets
                        # of this bin into k slots.
                        cur_head_nll = tf.einsum('ib,ibk->k', head_nll, perms[i])
                        cur_hidden = tf.einsum('ibd,ibk->kd', hidden, perms[i])
                        tail_logit = _logit(cur_hidden, cur_W, cur_b, cur_proj)

                        tail_target = tf.einsum(
                            'ib,ibk->k', tf.to_float(target - l_idx), perms[i])
                        tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=tf.to_int32(tail_target),
                            logits=tail_logit)

                        sum_nll = cur_head_nll + tail_nll
                        # Per-slot weight: sum of perms[i] over [i, b].
                        mask = tf.reduce_sum(perms[i], [0, 1])

                        total_loss += tf.reduce_sum(sum_nll * mask)
                        total_cnt += tf.reduce_sum(mask)

            nll = total_loss / total_cnt

    return nll
|
||||
|
||||
|
||||
def _create_mask(qlen, mlen, same_length=False):
    """Build the `[qlen, mlen + qlen]` attention mask.

    The first `mlen` columns are zeros; the remaining `qlen x qlen` part is
    strictly upper-triangular ones. With `same_length`, the strictly
    lower-triangular ones pattern is additionally added onto the first
    `qlen` columns of the result.

    Args:
        qlen: size of the square block (number of rows).
        mlen: number of zero columns placed in front.
        same_length: whether to add the lower-triangular part.

    Returns:
        A 2-D float tensor of zeros and ones.
    """
    ones = tf.ones([qlen, qlen])
    upper_incl = tf.matrix_band_part(ones, 0, -1)
    diagonal = tf.matrix_band_part(ones, 0, 0)
    memory_cols = tf.zeros([qlen, mlen])

    mask = tf.concat([memory_cols, upper_incl - diagonal], 1)
    if same_length:
        lower_incl = tf.matrix_band_part(ones, -1, 0)
        left = mask[:, :qlen] + lower_incl - diagonal
        right = mask[:, qlen:]
        mask = tf.concat([left, right], 1)
    return mask
|
||||
|
||||
def _cache_mem(curr_out, prev_mem, mem_len=None):
|
||||
if mem_len is None or prev_mem is None:
|
||||
new_mem = curr_out
|
||||
elif mem_len == 0:
|
||||
return prev_mem
|
||||
else:
|
||||
new_mem = tf.concat([prev_mem, curr_out], 0)[- mem_len:]
|
||||
|
||||
return tf.stop_gradient(new_mem)
|
||||
|
||||
|
||||
def transformer(dec_inp, target, mems, n_token, n_layer, d_model, d_embed,
                n_head, d_head, d_inner, dropout, dropatt,
                initializer, is_training, proj_initializer=None,
                mem_len=None, cutoffs=None, div_val=1, tie_projs=None,
                same_length=False, clamp_len=-1, use_tpu=True,
                input_perms=None, target_perms=None, head_target=None,
                untie_r=False, proj_same_dim=True,
                scope='transformer'):
    """Build the model graph and return the loss and the new memories.

    Args:
        dec_inp: 2-D integer input ids; axis 0 is the query length.
        target: integer target ids passed to the softmax loss.
        mems: list with one memory tensor per layer, or None.
        n_token: vocabulary size.
        n_layer: number of layers.
        d_model: model width.
        d_embed: embedding width.
        n_head: number of attention heads.
        d_head: width of each attention head.
        d_inner: hidden width of the feed-forward sub-layers.
        dropout: dropout rate.
        dropatt: dropout rate on attention probabilities.
        initializer: initializer for weights and biases.
        is_training: whether dropout is active.
        proj_initializer: initializer for projections; defaults to
            `initializer` when None.
        mem_len: number of rows to cache per layer (see `_cache_mem`).
        cutoffs: a list of python int. Cutoffs for adaptive softmax.
            None (the default) is treated as an empty list.
        div_val: per-bin width divisor of the adaptive embedding.
        tie_projs: a list of python bools. Whether to tie the projections.
            None (the default) is treated as an empty list.
        same_length: forwarded to `_create_mask`.
        clamp_len: if > 0, positions are clamped to this maximum.
        use_tpu: if True, use one_hot in embedding lookup and bin-based
            implementation of adaptive softmax.
        input_perms, target_perms: perms: a list of tensors. Each tensor
            should be of size [len, bsz, bin_size]. Only used in the
            adaptive setting.
        head_target: labels for the head softmax of the bin-based loss.
        untie_r: if True, use separate `r_w_bias` / `r_r_bias` per layer.
        proj_same_dim: forwarded to the embedding and softmax functions.
        scope: variable scope name.

    Returns:
        `(loss, new_mems)` where `new_mems` has one entry per layer.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    cutoffs = [] if cutoffs is None else cutoffs
    tie_projs = [] if tie_projs is None else tie_projs

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       initializer=initializer)

        qlen = tf.shape(dec_inp)[0]
        mlen = tf.shape(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        if proj_initializer is None:
            proj_initializer = initializer
        lookup_fn = (mul_adaptive_embedding_lookup if use_tpu else
                     mask_adaptive_embedding_lookup)
        embeddings, shared_params = lookup_fn(
            x=dec_inp,
            n_token=n_token,
            d_embed=d_embed,
            d_proj=d_model,
            cutoffs=cutoffs,
            initializer=initializer,
            proj_initializer=proj_initializer,
            div_val=div_val,
            perms=input_perms,
            proj_same_dim=proj_same_dim)

        attn_mask = _create_mask(qlen, mlen, same_length)

        # Positions count down from klen - 1 to 0.
        pos_seq = tf.range(klen - 1, -1, -1.0)
        if clamp_len > 0:
            pos_seq = tf.minimum(pos_seq, clamp_len)
        inv_freq = 1 / (10000 ** (tf.range(0, d_model, 2.0) / d_model))
        pos_emb = positional_embedding(pos_seq, inv_freq)

        output = tf.layers.dropout(embeddings, dropout, training=is_training)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        if mems is None:
            mems = [None] * n_layer

        for i in range(n_layer):
            # cache new mems: the input of layer i becomes its next memory
            new_mems.append(_cache_mem(output, mems[i], mem_len))

            with tf.variable_scope('layer_{}'.format(i)):
                output = rel_multihead_attn(
                    w=output,
                    r=pos_emb,
                    r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                    r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                    attn_mask=attn_mask,
                    mems=mems[i],
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer)
                output = positionwise_FF(
                    inp=output,
                    d_model=d_model,
                    d_inner=d_inner,
                    dropout=dropout,
                    kernel_initializer=initializer,
                    is_training=is_training)

        output = tf.layers.dropout(output, dropout, training=is_training)

        logsoftmax_fn = (mul_adaptive_logsoftmax if use_tpu else
                         mask_adaptive_logsoftmax)
        loss = logsoftmax_fn(
            hidden=output,
            target=target,
            n_token=n_token,
            d_embed=d_embed,
            d_proj=d_model,
            cutoffs=cutoffs,
            params=shared_params,
            tie_projs=tie_projs,
            initializer=initializer,
            proj_initializer=proj_initializer,
            div_val=div_val,
            perms=target_perms,
            head_target=head_target,
            proj_same_dim=proj_same_dim)
        return loss, new_mems
|
||||
|
||||
Reference in a new issue