import tensorflow as tf


def positional_embedding(pos_seq, inv_freq, bsz=None):
  sinusoid_inp = tf.einsum('i,j->ij', pos_seq, inv_freq)
  pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
  if bsz is not None:
    return tf.tile(pos_emb[:, None, :], [1, bsz, 1])
  else:
    return pos_emb[:, None, :]
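
# Illustrative sketch (not part of the original file): with klen = 4 and
# d_model = 8, the relative positional embedding used further below would be
# built as
#
#   pos_seq  = tf.range(4 - 1, -1, -1.0)                 # [3., 2., 1., 0.]
#   inv_freq = 1 / (10000 ** (tf.range(0, 8, 2.0) / 8))  # 4 frequencies
#   pos_emb  = positional_embedding(pos_seq, inv_freq)   # shape [4, 1, 8]
#
# Row t holds [sin(pos_t * f_0) .. sin(pos_t * f_3), cos(pos_t * f_0) ..
# cos(pos_t * f_3)]; the singleton middle axis broadcasts over the batch
# unless bsz is given.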


def positionwise_FF(inp, d_model, d_inner, dropout, kernel_initializer,
                    scope='ff', is_training=True):
  output = inp
  with tf.variable_scope(scope):
    output = tf.layers.dense(inp, d_inner, activation=tf.nn.relu,
                             kernel_initializer=kernel_initializer,
                             name='layer_1')
    output = tf.layers.dropout(output, dropout, training=is_training,
                               name='drop_1')
    output = tf.layers.dense(output, d_model,
                             kernel_initializer=kernel_initializer,
                             name='layer_2')
    output = tf.layers.dropout(output, dropout, training=is_training,
                               name='drop_2')
    output = tf.contrib.layers.layer_norm(output + inp, begin_norm_axis=-1)
  return output


def rel_shift(x):
  x_size = tf.shape(x)

  x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
  x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
  x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
  x = tf.reshape(x, x_size)

  return x
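
# What rel_shift does (illustrative note, added here, not from the original
# source): the BD term below is computed against relative positions ordered
# from farthest to nearest, so row i needs its columns shifted left by
# (qlen - 1 - i) to line up with the keys in AC. The pad/reshape/slice trick
# above achieves that without a gather: in the valid (unmasked) region,
# out[i, j] == x[i, j + qlen - 1 - i]. Tiny example with qlen = klen = 2 and
# scalar batch/head axes:
#
#   x = [[a, b],        out = [[b, 0],
#        [c, d]]               [c, d]]
#
# The stray zero from the padding only lands on positions that the causal
# attention mask removes anyway.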


def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model,
                       n_head, d_head, dropout, dropatt, is_training,
                       kernel_initializer, scope='rel_attn'):
  scale = 1 / (d_head ** 0.5)
  with tf.variable_scope(scope):
    qlen = tf.shape(w)[0]
    rlen = tf.shape(r)[0]
    bsz = tf.shape(w)[1]

    cat = tf.concat([mems, w],
                    0) if mems is not None and mems.shape.ndims > 1 else w
    w_heads = tf.layers.dense(cat, 3 * n_head * d_head, use_bias=False,
                              kernel_initializer=kernel_initializer, name='qkv')
    r_head_k = tf.layers.dense(r, n_head * d_head, use_bias=False,
                               kernel_initializer=kernel_initializer, name='r')

    w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1)
    w_head_q = w_head_q[-qlen:]

    klen = tf.shape(w_head_k)[0]

    w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head])
    w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head])
    w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head])

    r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head])

    rw_head_q = w_head_q + r_w_bias
    rr_head_q = w_head_q + r_r_bias

    AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
    BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
    BD = rel_shift(BD)

    attn_score = (AC + BD) * scale
    attn_mask_t = attn_mask[:, :, None, None]
    attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t

    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
    size_t = tf.shape(attn_vec)
    attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head])

    attn_out = tf.layers.dense(attn_vec, d_model, use_bias=False,
                               kernel_initializer=kernel_initializer, name='o')
    attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)

    output = tf.contrib.layers.layer_norm(attn_out + w, begin_norm_axis=-1)
  return output
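
# Shape note (added for clarity, not in the original file): inputs are
# time-major. With current segment length qlen, memory length mlen and
# klen = mlen + qlen:
#   w: [qlen, bsz, d_model]   r: [klen, 1, d_model]   mems: [mlen, bsz, d_model]
# AC scores the content-based query term (q + r_w_bias) against the keys,
# while BD scores the position-based term (q + r_r_bias) against the relative
# embeddings; rel_shift realigns BD so that (AC + BD)[i, j, b, n] is the score
# of query i against key j, following the Transformer-XL decomposition of
# relative attention. The weighted values are projected back to
# [qlen, bsz, d_model] and combined with a residual connection and layer norm.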


def embedding_lookup(lookup_table, x, use_tpu=True):
  if use_tpu:
    n_token = tf.shape(lookup_table)[0]
    one_hot_idx = tf.one_hot(x, n_token)
    if one_hot_idx.shape.ndims == 2:
      return tf.einsum('nd,in->id', lookup_table, one_hot_idx)
    else:
      return tf.einsum('nd,ibn->ibd', lookup_table, one_hot_idx)
  else:
    return tf.nn.embedding_lookup(lookup_table, x)
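
# Note (added): on TPU the one_hot + einsum formulation is used instead of
# tf.nn.embedding_lookup because dense matmuls are typically faster there than
# gather; both branches return the table rows for the ids in x, i.e. a tensor
# of shape x.shape + [table_width].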


def mask_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer,
                                   proj_initializer, div_val=1,
                                   proj_same_dim=True,
                                   scope='adaptive_embed', **kwargs):
  emb_scale = d_proj ** 0.5
  with tf.variable_scope(scope):
    if div_val == 1:
      lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
                                     initializer=initializer)
      y = embedding_lookup(lookup_table, x, use_tpu=False)
      if d_proj != d_embed:
        proj_W = tf.get_variable('proj_W', [d_embed, d_proj],
                                 initializer=proj_initializer)
        y = tf.einsum('ibe,ed->ibd', y, proj_W)
      else:
        proj_W = None
      ret_params = [lookup_table, proj_W]
    else:
      tables, projs = [], []
      cutoff_ends = [0] + cutoffs + [n_token]
      x_size = tf.shape(x)
      y = tf.zeros([x_size[0], x_size[1], d_proj])
      for i in range(len(cutoff_ends) - 1):
        with tf.variable_scope('cutoff_{}'.format(i)):
          l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
          mask = (x >= l_idx) & (x < r_idx)
          cur_x = tf.boolean_mask(x, mask) - l_idx
          cur_d_embed = d_embed // (div_val ** i)
          lookup_table = tf.get_variable('lookup_table',
                                         [r_idx - l_idx, cur_d_embed],
                                         initializer=initializer)
          cur_y = embedding_lookup(lookup_table, cur_x, use_tpu=False)
          if d_proj == cur_d_embed and not proj_same_dim:
            proj_W = None
          else:
            proj_W = tf.get_variable('proj_W', [cur_d_embed, d_proj],
                                     initializer=proj_initializer)
            cur_y = tf.einsum('id,de->ie', cur_y, proj_W)
          mask_idx = tf.to_int64(tf.where(mask))
          y += tf.scatter_nd(mask_idx, cur_y, tf.to_int64(tf.shape(y)))
          tables.append(lookup_table)
          projs.append(proj_W)
      ret_params = [tables, projs]

    y *= emb_scale
  return y, ret_params
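
# Worked example (illustrative, hypothetical sizes): with n_token = 10000,
# cutoffs = [1000, 5000], d_embed = 512 and div_val = 2, the loop above builds
# three bins covering ids [0, 1000), [1000, 5000) and [5000, 10000) with
# embedding widths 512, 256 and 128, each followed by a projection back to
# d_proj (skipped only when the widths already match and proj_same_dim is
# False); the per-bin embeddings are scattered back into one
# [len, bsz, d_proj] tensor via tf.scatter_nd.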


def mul_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer,
                                  proj_initializer, div_val=1, perms=None,
                                  proj_same_dim=True,
                                  scope='adaptive_embed'):
  """
  perms: If None, first compute W = W1 x W2 (projection for each bin),
    and then compute X x W (embedding lookup). If not None,
    use bin-based embedding lookup with max_bin_size defined by
    the shape of perms.
  """
  emb_scale = d_proj ** 0.5
  with tf.variable_scope(scope):
    if div_val == 1:
      lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
                                     initializer=initializer)
      y = embedding_lookup(lookup_table, x)
      if d_proj != d_embed:
        proj_W = tf.get_variable('proj_W', [d_embed, d_proj],
                                 initializer=proj_initializer)
        y = tf.einsum('ibe,ed->ibd', y, proj_W)
      else:
        proj_W = None
      ret_params = [lookup_table, proj_W]
    else:
      tables, projs = [], []
      cutoff_ends = [0] + cutoffs + [n_token]
      x_size = tf.shape(x)
      if perms is None:
        cat_lookup = []
      else:
        cat_lookup = tf.zeros([x_size[0], x_size[1], d_proj])
      for i in range(len(cutoff_ends) - 1):
        with tf.variable_scope('cutoff_{}'.format(i)):
          l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
          cur_d_embed = d_embed // (div_val ** i)
          lookup_table = tf.get_variable('lookup_table',
                                         [r_idx - l_idx, cur_d_embed],
                                         initializer=initializer)
          if cur_d_embed == d_proj and not proj_same_dim:
            proj_W = None
          else:
            proj_W = tf.get_variable('proj_W', [cur_d_embed, d_proj],
                                     initializer=proj_initializer)
          if perms is None:
            cat_lookup.append(tf.einsum('ie,ed->id', lookup_table, proj_W))
          else:
            # speed up the computation of the first bin
            # also save some memory
            if i == 0:
              cur_y = embedding_lookup(lookup_table, tf.minimum(x, r_idx - 1))
              if proj_W is not None:
                cur_y = tf.einsum('ibe,ed->ibd', cur_y, proj_W)
              cur_y *= perms[i][:, :, None]
              cat_lookup += cur_y
            else:
              cur_x = tf.einsum('ib,ibk->k', tf.to_float(x - l_idx), perms[i])
              cur_x = tf.to_int32(cur_x)
              cur_y = embedding_lookup(lookup_table, cur_x)
              if proj_W is not None:
                cur_y = tf.einsum('ke,ed->kd', cur_y, proj_W)
              cat_lookup += tf.einsum('kd,ibk->ibd', cur_y, perms[i])
          tables.append(lookup_table)
          projs.append(proj_W)
      if perms is None:
        cat_lookup = tf.concat(cat_lookup, 0)
        y = embedding_lookup(cat_lookup, x)
      else:
        y = cat_lookup
      ret_params = [tables, projs]

    y *= emb_scale
  return y, ret_params
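
# Note on perms (added for clarity): in this TPU-oriented path each perms[i]
# is expected to be a [len, bsz, bin_size] one-hot style assignment of
# positions to the slots of bin i, so the einsum calls above stand in for the
# boolean_mask/scatter_nd of mask_adaptive_embedding_lookup, keeping all
# intermediate shapes static.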


def mask_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
                             params, tie_projs,
                             initializer=None, proj_initializer=None,
                             div_val=1, scope='adaptive_softmax',
                             proj_same_dim=True,
                             return_mean=True, **kwargs):
  def _logit(x, W, b, proj):
    y = x
    if proj is not None:
      y = tf.einsum('ibd,ed->ibe', y, proj)
    return tf.einsum('ibd,nd->ibn', y, W) + b

  params_W, params_projs = params[0], params[1]

  def _gather_logprob(logprob, target):
    lp_size = tf.shape(logprob)
    r = tf.range(lp_size[0])
    idx = tf.stack([r, target], 1)
    return tf.gather_nd(logprob, idx)

  with tf.variable_scope(scope):
    if len(cutoffs) == 0:
      softmax_b = tf.get_variable('bias', [n_token],
                                  initializer=tf.zeros_initializer())
      output = _logit(hidden, params_W, softmax_b, params_projs)
      nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
                                                           logits=output)
    else:
      cutoff_ends = [0] + cutoffs + [n_token]
      nll = tf.zeros_like(target, dtype=tf.float32)
      for i in range(len(cutoff_ends) - 1):
        with tf.variable_scope('cutoff_{}'.format(i)):
          l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
          mask = (target >= l_idx) & (target < r_idx)
          mask_idx = tf.where(mask)
          cur_target = tf.boolean_mask(target, mask) - l_idx
          cur_d_embed = d_embed // (div_val ** i)

          if div_val == 1:
            cur_W = params_W[l_idx: r_idx]
          else:
            cur_W = params_W[i]
          cur_b = tf.get_variable('b', [r_idx - l_idx],
                                  initializer=tf.zeros_initializer())
          if tie_projs[i]:
            if div_val == 1:
              cur_proj = params_projs
            else:
              cur_proj = params_projs[i]
          else:
            if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed:
              cur_proj = None
            else:
              cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj],
                                         initializer=proj_initializer)
          if i == 0:
            cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed],
                                        initializer=tf.zeros_initializer())
            cluster_b = tf.get_variable('cluster_b', [len(cutoffs)],
                                        initializer=tf.zeros_initializer())
            cur_W = tf.concat([cur_W, cluster_W], 0)
            cur_b = tf.concat([cur_b, cluster_b], 0)

            head_logit = _logit(hidden, cur_W, cur_b, cur_proj)
            head_logprob = tf.nn.log_softmax(head_logit)
            cur_head_logprob = tf.boolean_mask(head_logprob, mask)
            cur_logprob = _gather_logprob(cur_head_logprob, cur_target)
          else:
            cur_head_logprob = tf.boolean_mask(head_logprob, mask)
            cur_hidden = tf.boolean_mask(hidden, mask)
            tail_logit = tf.squeeze(_logit(
                cur_hidden[None], cur_W, cur_b, cur_proj), 0)
            tail_logprob = tf.nn.log_softmax(tail_logit)
            cur_logprob = (cur_head_logprob[:, cutoff_ends[1] + i - 1] +
                           _gather_logprob(tail_logprob, cur_target))
          nll += tf.scatter_nd(mask_idx, -cur_logprob,
                               tf.to_int64(tf.shape(nll)))
    if return_mean:
      nll = tf.reduce_mean(nll)
  return nll
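
# Note (added): this follows the usual adaptive-softmax factorization. The
# head softmax (i == 0) covers the first cutoff_ends[1] tokens plus one
# "cluster" logit per tail bin; for a token in tail bin i its log-probability
# is the head log-prob of that bin's cluster logit (index
# cutoff_ends[1] + i - 1) plus the log-prob of the token inside the bin's own
# tail softmax, which is exactly what the else branch above computes.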


def mul_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
                            params, tie_projs,
                            initializer=None, proj_initializer=None,
                            div_val=1, perms=None, proj_same_dim=True,
                            scope='adaptive_softmax',
                            **kwargs):
  def _logit(x, W, b, proj):
    y = x
    if x.shape.ndims == 3:
      if proj is not None:
        y = tf.einsum('ibd,ed->ibe', y, proj)
      return tf.einsum('ibd,nd->ibn', y, W) + b
    else:
      if proj is not None:
        y = tf.einsum('id,ed->ie', y, proj)
      return tf.einsum('id,nd->in', y, W) + b

  params_W, params_projs = params[0], params[1]

  with tf.variable_scope(scope):
    if len(cutoffs) == 0:
      softmax_b = tf.get_variable('bias', [n_token],
                                  initializer=tf.zeros_initializer())
      output = _logit(hidden, params_W, softmax_b, params_projs)
      nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
                                                           logits=output)
      nll = tf.reduce_mean(nll)
    else:
      total_loss, total_cnt = 0, 0
      cutoff_ends = [0] + cutoffs + [n_token]
      for i in range(len(cutoff_ends) - 1):
        with tf.variable_scope('cutoff_{}'.format(i)):
          l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]

          cur_d_embed = d_embed // (div_val ** i)

          if div_val == 1:
            cur_W = params_W[l_idx: r_idx]
          else:
            cur_W = params_W[i]
          cur_b = tf.get_variable('b', [r_idx - l_idx],
                                  initializer=tf.zeros_initializer())
          if tie_projs[i]:
            if div_val == 1:
              cur_proj = params_projs
            else:
              cur_proj = params_projs[i]
          else:
            if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed:
              cur_proj = None
            else:
              cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj],
                                         initializer=proj_initializer)

          if i == 0:
            cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed],
                                        initializer=tf.zeros_initializer())
            cluster_b = tf.get_variable('cluster_b', [len(cutoffs)],
                                        initializer=tf.zeros_initializer())
            cur_W = tf.concat([cur_W, cluster_W], 0)
            cur_b = tf.concat([cur_b, cluster_b], 0)

            head_logit = _logit(hidden, cur_W, cur_b, cur_proj)

            head_target = kwargs.get("head_target")
            head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=head_target,
                logits=head_logit)

            masked_loss = head_nll * perms[i]
            total_loss += tf.reduce_sum(masked_loss)
            total_cnt += tf.reduce_sum(perms[i])

            # head_logprob = tf.nn.log_softmax(head_logit)

            # final_logprob = head_logprob * perms[i][:, :, None]
            # final_target = tf.one_hot(target, tf.shape(head_logprob)[2])
            # total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target)
            # total_cnt += tf.reduce_sum(perms[i])
          else:
            cur_head_nll = tf.einsum('ib,ibk->k', head_nll, perms[i])

            cur_hidden = tf.einsum('ibd,ibk->kd', hidden, perms[i])
            tail_logit = _logit(cur_hidden, cur_W, cur_b, cur_proj)

            tail_target = tf.einsum('ib,ibk->k', tf.to_float(target - l_idx),
                                    perms[i])
            tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.to_int32(tail_target),
                logits=tail_logit)

            sum_nll = cur_head_nll + tail_nll
            mask = tf.reduce_sum(perms[i], [0, 1])

            masked_loss = sum_nll * mask
            total_loss += tf.reduce_sum(masked_loss)
            total_cnt += tf.reduce_sum(mask)

      nll = total_loss / total_cnt

  return nll


def _create_mask(qlen, mlen, same_length=False):
  attn_mask = tf.ones([qlen, qlen])
  mask_u = tf.matrix_band_part(attn_mask, 0, -1)
  mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
  attn_mask_pad = tf.zeros([qlen, mlen])
  ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
  if same_length:
    mask_l = tf.matrix_band_part(attn_mask, -1, 0)
    ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
  return ret
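
# Illustrative example (not part of the original file): with qlen = 3 and
# mlen = 2 (same_length=False) the mask has shape [qlen, mlen + qlen], with 1
# marking disallowed (future) positions:
#
#   [[0, 0, 0, 1, 1],
#    [0, 0, 0, 0, 1],
#    [0, 0, 0, 0, 0]]
#
# so every query can attend to all memory slots and to itself, but not to
# later positions in the current segment.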


def _cache_mem(curr_out, prev_mem, mem_len=None):
  if mem_len is None or prev_mem is None:
    new_mem = curr_out
  elif mem_len == 0:
    return prev_mem
  else:
    new_mem = tf.concat([prev_mem, curr_out], 0)[- mem_len:]

  return tf.stop_gradient(new_mem)
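
# Note (added): in the general case the new memory is the current segment's
# layer input concatenated onto the previous memory, truncated to the last
# mem_len steps, and wrapped in stop_gradient so gradients never flow into
# past segments, as in the Transformer-XL recurrence.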


def transformer(dec_inp, target, mems, n_token, n_layer, d_model, d_embed,
                n_head, d_head, d_inner, dropout, dropatt,
                initializer, is_training, proj_initializer=None,
                mem_len=None, cutoffs=[], div_val=1, tie_projs=[],
                same_length=False, clamp_len=-1, use_tpu=True,
                input_perms=None, target_perms=None, head_target=None,
                untie_r=False, proj_same_dim=True,
                scope='transformer'):
  """
  cutoffs: a list of python int. Cutoffs for adaptive softmax.
  tie_projs: a list of python bools. Whether to tie the projections.
  use_tpu: if True, use one_hot in embedding lookup and bin-based implementation
        of adaptive softmax.
  perms: a list of tensors. Each tensor should be of size [len, bsz, bin_size].
        Only used in the adaptive setting.
  """
  new_mems = []
  with tf.variable_scope(scope):
    if untie_r:
      r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                 initializer=initializer)
      r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                 initializer=initializer)
    else:
      r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                 initializer=initializer)
      r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                 initializer=initializer)

    qlen = tf.shape(dec_inp)[0]
    mlen = tf.shape(mems[0])[0] if mems is not None else 0
    klen = mlen + qlen

    if proj_initializer is None:
      proj_initializer = initializer
    lookup_fn = (mul_adaptive_embedding_lookup if use_tpu else
                 mask_adaptive_embedding_lookup)
    embeddings, shared_params = lookup_fn(
        x=dec_inp,
        n_token=n_token,
        d_embed=d_embed,
        d_proj=d_model,
        cutoffs=cutoffs,
        initializer=initializer,
        proj_initializer=proj_initializer,
        div_val=div_val,
        perms=input_perms,
        proj_same_dim=proj_same_dim)

    attn_mask = _create_mask(qlen, mlen, same_length)

    pos_seq = tf.range(klen - 1, -1, -1.0)
    if clamp_len > 0:
      pos_seq = tf.minimum(pos_seq, clamp_len)
    inv_freq = 1 / (10000 ** (tf.range(0, d_model, 2.0) / d_model))
    pos_emb = positional_embedding(pos_seq, inv_freq)

    output = tf.layers.dropout(embeddings, dropout, training=is_training)
    pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

    if mems is None:
      mems = [None] * n_layer

    for i in range(n_layer):
      # cache new mems
      new_mems.append(_cache_mem(output, mems[i], mem_len))

      with tf.variable_scope('layer_{}'.format(i)):
        output = rel_multihead_attn(
            w=output,
            r=pos_emb,
            r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
            r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
            attn_mask=attn_mask,
            mems=mems[i],
            d_model=d_model,
            n_head=n_head,
            d_head=d_head,
            dropout=dropout,
            dropatt=dropatt,
            is_training=is_training,
            kernel_initializer=initializer)
        output = positionwise_FF(
            inp=output,
            d_model=d_model,
            d_inner=d_inner,
            dropout=dropout,
            kernel_initializer=initializer,
            is_training=is_training)

    output = tf.layers.dropout(output, dropout, training=is_training)

    logsoftmax_fn = (mul_adaptive_logsoftmax if use_tpu else
                     mask_adaptive_logsoftmax)
    loss = logsoftmax_fn(
        hidden=output,
        target=target,
        n_token=n_token,
        d_embed=d_embed,
        d_proj=d_model,
        cutoffs=cutoffs,
        params=shared_params,
        tie_projs=tie_projs,
        initializer=initializer,
        proj_initializer=proj_initializer,
        div_val=div_val,
        perms=target_perms,
        head_target=head_target,
        proj_same_dim=proj_same_dim)
  return loss, new_mems
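
# Minimal usage sketch (added for illustration only; hypothetical
# hyperparameters, GPU/non-adaptive path with use_tpu=False):
#
#   dec_inp = tf.placeholder(tf.int32, [32, 4])        # [tgt_len, bsz]
#   target  = tf.placeholder(tf.int32, [32, 4])
#   mems    = [tf.placeholder(tf.float32, [96, 4, 200]) for _ in range(4)]
#   loss, new_mems = transformer(
#       dec_inp=dec_inp, target=target, mems=mems,
#       n_token=10000, n_layer=4, d_model=200, d_embed=200,
#       n_head=2, d_head=32, d_inner=1000, dropout=0.1, dropatt=0.1,
#       initializer=tf.initializers.random_normal(stddev=0.02),
#       is_training=True, mem_len=96, cutoffs=[], tie_projs=[],
#       use_tpu=False)
#
# loss is the scalar mean NLL for the segment and new_mems holds the updated
# per-layer memories to feed back in on the next segment.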