This repository was archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing or creating new issues, pull requests, or comments.
2025ML-project-neural_compr.../transformer-xl/tf/model.py

546 lines
21 KiB
Python

import tensorflow as tf
def positional_embedding(pos_seq, inv_freq, bsz=None):
    """Build sinusoidal position embeddings from positions and frequencies.

    Args:
      pos_seq: 1-D tensor of positions (einsum index ``i``).
      inv_freq: 1-D tensor of inverse frequencies (einsum index ``j``).
      bsz: optional batch size. When given, the embedding is tiled ``bsz``
        times along the inserted middle axis.

    Returns:
      A 3-D tensor whose axes are [position, 1 or bsz, sin features followed
      by cos features].
    """
    # Outer product: one row per position, one column per frequency.
    angles = tf.einsum('i,j->ij', pos_seq, inv_freq)
    table = tf.concat([tf.sin(angles), tf.cos(angles)], -1)

    # Insert a length-1 axis in the middle.
    expanded = table[:, None, :]
    if bsz is None:
        return expanded
    return tf.tile(expanded, [1, bsz, 1])
def positionwise_FF(inp, d_model, d_inner, dropout, kernel_initializer,
                    scope='ff', is_training=True):
    """Two-layer feed-forward sub-layer with residual connection and layer norm.

    Args:
      inp: input tensor; its last axis must match ``d_model`` for the
        residual addition to be valid.
      d_model: output width of the second dense layer.
      d_inner: width of the hidden (ReLU) dense layer.
      dropout: dropout rate applied after each dense layer.
      kernel_initializer: initializer for both dense kernels.
      scope: variable scope under which the layers are created.
      is_training: forwarded to ``tf.layers.dropout`` as ``training``.

    Returns:
      ``layer_norm(FF(inp) + inp)``, normalized over the last axis.
    """
    with tf.variable_scope(scope):
        hidden = tf.layers.dense(
            inp, d_inner,
            activation=tf.nn.relu,
            kernel_initializer=kernel_initializer,
            name='layer_1')
        hidden = tf.layers.dropout(
            hidden, dropout, training=is_training, name='drop_1')

        projected = tf.layers.dense(
            hidden, d_model,
            kernel_initializer=kernel_initializer,
            name='layer_2')
        projected = tf.layers.dropout(
            projected, dropout, training=is_training, name='drop_2')

        # Residual connection followed by layer normalization.
        residual = projected + inp
        normed = tf.contrib.layers.layer_norm(residual, begin_norm_axis=-1)
    return normed
def rel_shift(x):
    """Shift a 4-D score tensor via a pad / reshape / slice / reshape sequence.

    A zero column is prepended along axis 1, the result is reinterpreted
    with the first two axis sizes swapped (axis 1 size + 1 first), the first
    row is dropped, and the data is reshaped back to the original shape of
    ``x``.

    Args:
      x: 4-D tensor (the paddings and reshape targets fix the rank at 4).

    Returns:
      Tensor with the same shape as ``x``.
    """
    orig_shape = tf.shape(x)

    padded = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
    swapped_shape = [orig_shape[1] + 1, orig_shape[0],
                     orig_shape[2], orig_shape[3]]
    swapped = tf.reshape(padded, swapped_shape)
    # Discard the first row of the reinterpreted tensor.
    trimmed = tf.slice(swapped, [1, 0, 0, 0], [-1, -1, -1, -1])
    return tf.reshape(trimmed, orig_shape)
def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model,
                       n_head, d_head, dropout, dropatt, is_training,
                       kernel_initializer, scope='rel_attn'):
    """Multi-head attention with relative position terms, residual and layer norm.

    Args:
      w: content input; axis 0 is the query length and axis 1 the batch size
        (both are read from ``tf.shape(w)``).
      r: position embeddings; axis 0 is the number of relative positions.
      r_w_bias: bias added to the queries for the content term (``AC``).
      r_r_bias: bias added to the queries for the position term (``BD``).
      attn_mask: 2-D mask over [query, key]; entries equal to 1 are masked
        out (their score is replaced by ``-1e30``).
      mems: optional memory tensor prepended to ``w`` along axis 0 before the
        key/value projection. Ignored when ``None`` or of rank <= 1.
      d_model: width of the output projection.
      n_head: number of attention heads.
      d_head: width of each head.
      dropout: dropout rate applied to the output projection.
      dropatt: dropout rate applied to the attention probabilities.
      is_training: forwarded to ``tf.layers.dropout`` as ``training``.
      kernel_initializer: initializer for all dense kernels.
      scope: variable scope under which the layers are created.

    Returns:
      ``layer_norm(attention_output + w)``, normalized over the last axis.
    """
    scale = 1 / (d_head ** 0.5)
    with tf.variable_scope(scope):
        qlen = tf.shape(w)[0]
        rlen = tf.shape(r)[0]
        bsz = tf.shape(w)[1]

        # Keys/values see the cached memory followed by the current segment.
        has_mems = mems is not None and mems.shape.ndims > 1
        kv_input = tf.concat([mems, w], 0) if has_mems else w

        qkv = tf.layers.dense(
            kv_input, 3 * n_head * d_head, use_bias=False,
            kernel_initializer=kernel_initializer, name='qkv')
        pos_k = tf.layers.dense(
            r, n_head * d_head, use_bias=False,
            kernel_initializer=kernel_initializer, name='r')

        q, k, v = tf.split(qkv, 3, -1)
        # Queries are only needed for the current segment (last qlen rows).
        q = q[-qlen:]
        klen = tf.shape(k)[0]

        q = tf.reshape(q, [qlen, bsz, n_head, d_head])
        k = tf.reshape(k, [klen, bsz, n_head, d_head])
        v = tf.reshape(v, [klen, bsz, n_head, d_head])
        pos_k = tf.reshape(pos_k, [rlen, n_head, d_head])

        q_content = q + r_w_bias
        q_position = q + r_r_bias

        # Content-based and position-based score terms, both [i, j, b, n].
        content_score = tf.einsum('ibnd,jbnd->ijbn', q_content, k)
        position_score = tf.einsum('ibnd,jnd->ijbn', q_position, pos_k)
        position_score = rel_shift(position_score)

        score = (content_score + position_score) * scale
        # Broadcast the [query, key] mask over batch and head axes.
        mask_4d = attn_mask[:, :, None, None]
        score = score * (1 - mask_4d) - 1e30 * mask_4d

        # Normalize over the key axis.
        prob = tf.nn.softmax(score, 1)
        prob = tf.layers.dropout(prob, dropatt, training=is_training)

        context = tf.einsum('ijbn,jbnd->ibnd', prob, v)
        ctx_shape = tf.shape(context)
        # Merge the head axes back into a single feature axis.
        context = tf.reshape(
            context, [ctx_shape[0], ctx_shape[1], n_head * d_head])

        projected = tf.layers.dense(
            context, d_model, use_bias=False,
            kernel_initializer=kernel_initializer, name='o')
        projected = tf.layers.dropout(projected, dropout, training=is_training)

        output = tf.contrib.layers.layer_norm(
            projected + w, begin_norm_axis=-1)
    return output
def embedding_lookup(lookup_table, x, use_tpu=True):
    """Look up rows of ``lookup_table`` for the ids in ``x``.

    Args:
      lookup_table: 2-D embedding matrix (einsum indices ``nd``).
      x: integer ids. In the ``use_tpu`` path ``x`` must be 1-D or 2-D so
        that its one-hot encoding is 2-D or 3-D.
      use_tpu: when True, the lookup is a one-hot matrix multiply; otherwise
        ``tf.nn.embedding_lookup`` is used.

    Returns:
      The embeddings of ``x``, with a trailing embedding axis.
    """
    if not use_tpu:
        return tf.nn.embedding_lookup(lookup_table, x)

    vocab_size = tf.shape(lookup_table)[0]
    one_hot = tf.one_hot(x, vocab_size)
    # Rank-2 one-hot means x was 1-D; otherwise treat x as [len, bsz].
    equation = 'nd,in->id' if one_hot.shape.ndims == 2 else 'nd,ibn->ibd'
    return tf.einsum(equation, lookup_table, one_hot)
def mask_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer,
                                   proj_initializer, div_val=1,
                                   proj_same_dim=True,
                                   scope='adaptive_embed', **kwargs):
    """Adaptive embedding lookup implemented with boolean masks.

    Args:
      x: integer token ids, indexed as [len, bsz] (both sizes are read from
        ``tf.shape(x)`` in the adaptive branch).
      n_token: vocabulary size.
      d_embed: embedding width of the first bin; bin ``i`` uses
        ``d_embed // div_val ** i``.
      d_proj: output width; the result is scaled by ``d_proj ** 0.5``.
      cutoffs: list of python ints splitting the vocabulary into bins.
      initializer: initializer for the lookup tables.
      proj_initializer: initializer for the projection matrices.
      div_val: when 1 a single table is used; otherwise one table per bin.
      proj_same_dim: when False, a bin whose width already equals ``d_proj``
        gets no projection.
      scope: variable scope under which variables are created.
      **kwargs: accepted and ignored.

    Returns:
      ``(y, ret_params)`` where ``y`` is the scaled embedding and
      ``ret_params`` is ``[lookup_table, proj_W]`` when ``div_val == 1`` or
      ``[tables, projs]`` (two lists, one entry per bin) otherwise.
      Projection entries may be ``None``.
    """
    emb_scale = d_proj ** 0.5
    with tf.variable_scope(scope):
        if div_val == 1:
            lookup_table = tf.get_variable(
                'lookup_table', [n_token, d_embed], initializer=initializer)
            y = embedding_lookup(lookup_table, x, use_tpu=False)
            proj_W = None
            if d_proj != d_embed:
                proj_W = tf.get_variable(
                    'proj_W', [d_embed, d_proj], initializer=proj_initializer)
                y = tf.einsum('ibe,ed->ibd', y, proj_W)
            ret_params = [lookup_table, proj_W]
        else:
            tables = []
            projs = []
            cutoff_ends = [0] + cutoffs + [n_token]
            in_shape = tf.shape(x)
            y = tf.zeros([in_shape[0], in_shape[1], d_proj])
            bin_bounds = zip(cutoff_ends[:-1], cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bin_bounds):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    # Select the tokens that fall in this bin and re-base
                    # their ids to the start of the bin.
                    in_bin = (x >= l_idx) & (x < r_idx)
                    bin_ids = tf.boolean_mask(x, in_bin) - l_idx
                    bin_width = d_embed // (div_val ** i)
                    lookup_table = tf.get_variable(
                        'lookup_table', [r_idx - l_idx, bin_width],
                        initializer=initializer)
                    bin_emb = embedding_lookup(
                        lookup_table, bin_ids, use_tpu=False)

                    skip_proj = d_proj == bin_width and not proj_same_dim
                    if skip_proj:
                        proj_W = None
                    else:
                        proj_W = tf.get_variable(
                            'proj_W', [bin_width, d_proj],
                            initializer=proj_initializer)
                        bin_emb = tf.einsum('id,de->ie', bin_emb, proj_W)

                    # Scatter this bin's embeddings back to their positions.
                    positions = tf.to_int64(tf.where(in_bin))
                    y = y + tf.scatter_nd(
                        positions, bin_emb, tf.to_int64(tf.shape(y)))
                    tables.append(lookup_table)
                    projs.append(proj_W)
            ret_params = [tables, projs]

    y = y * emb_scale
    return y, ret_params
def mul_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer,
                                  proj_initializer, div_val=1, perms=None,
                                  proj_same_dim=True,
                                  scope='adaptive_embed'):
    """Adaptive embedding lookup implemented with matrix multiplies.

    Args:
      x: integer token ids, indexed as [len, bsz].
      n_token: vocabulary size.
      d_embed: embedding width of the first bin; bin ``i`` uses
        ``d_embed // div_val ** i``.
      d_proj: output width; the result is scaled by ``d_proj ** 0.5``.
      cutoffs: list of python ints splitting the vocabulary into bins.
      initializer: initializer for the lookup tables.
      proj_initializer: initializer for the projection matrices.
      div_val: when 1 a single table is used; otherwise one table per bin.
      perms: If None, first compute W = W1 x W2 (projection for each bin),
        and then compute X x W (embedding lookup). If not None,
        use bin-based embedding lookup with max_bin_size defined by
        the shape of perms. ``perms[0]`` is indexed as [len, bsz] and
        ``perms[i]`` for ``i > 0`` as [len, bsz, bin_size].
      proj_same_dim: when False, a bin whose width already equals ``d_proj``
        gets no projection.
      scope: variable scope under which variables are created.

    Returns:
      ``(y, ret_params)`` where ``y`` is the scaled embedding and
      ``ret_params`` is ``[lookup_table, proj_W]`` when ``div_val == 1`` or
      ``[tables, projs]`` (two lists, one entry per bin) otherwise.
      Projection entries may be ``None``.
    """
    emb_scale = d_proj ** 0.5
    with tf.variable_scope(scope):
        if div_val == 1:
            lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
                                           initializer=initializer)
            y = embedding_lookup(lookup_table, x)
            if d_proj != d_embed:
                proj_W = tf.get_variable('proj_W', [d_embed, d_proj],
                                         initializer=proj_initializer)
                y = tf.einsum('ibe,ed->ibd', y, proj_W)
            else:
                proj_W = None
            ret_params = [lookup_table, proj_W]
        else:
            tables, projs = [], []
            cutoff_ends = [0] + cutoffs + [n_token]
            if perms is None:
                # Collect one projected [bin_size, d_proj] table per bin.
                cat_lookup = []
            else:
                # Accumulate the per-bin contributions directly.
                x_size = tf.shape(x)
                cat_lookup = tf.zeros([x_size[0], x_size[1], d_proj])
            for i in range(len(cutoff_ends) - 1):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
                    cur_d_embed = d_embed // (div_val ** i)
                    lookup_table = tf.get_variable(
                        'lookup_table', [r_idx - l_idx, cur_d_embed],
                        initializer=initializer)
                    if cur_d_embed == d_proj and not proj_same_dim:
                        proj_W = None
                    else:
                        proj_W = tf.get_variable(
                            'proj_W', [cur_d_embed, d_proj],
                            initializer=proj_initializer)
                    if perms is None:
                        # proj_W is None only when the table is already
                        # d_proj wide, so it can be used unprojected.
                        # Passing None to einsum would fail at graph build.
                        if proj_W is None:
                            cat_lookup.append(lookup_table)
                        else:
                            cat_lookup.append(
                                tf.einsum('ie,ed->id', lookup_table, proj_W))
                    else:
                        # speed up the computation of the first bin
                        # also save some memory
                        if i == 0:
                            # Clamp ids into the first bin; positions that do
                            # not belong to it are zeroed by perms[0] below.
                            cur_y = embedding_lookup(
                                lookup_table, tf.minimum(x, r_idx - 1))
                            if proj_W is not None:
                                cur_y = tf.einsum('ibe,ed->ibd', cur_y, proj_W)
                            cur_y *= perms[i][:, :, None]
                            cat_lookup += cur_y
                        else:
                            # Gather the bin-relative ids into the k slots,
                            # embed them, then scatter back to [len, bsz].
                            cur_x = tf.einsum(
                                'ib,ibk->k', tf.to_float(x - l_idx), perms[i])
                            cur_x = tf.to_int32(cur_x)
                            cur_y = embedding_lookup(lookup_table, cur_x)
                            if proj_W is not None:
                                cur_y = tf.einsum('ke,ed->kd', cur_y, proj_W)
                            cat_lookup += tf.einsum(
                                'kd,ibk->ibd', cur_y, perms[i])
                    tables.append(lookup_table)
                    projs.append(proj_W)
            if perms is None:
                # One [n_token, d_proj] table, then a single lookup.
                cat_lookup = tf.concat(cat_lookup, 0)
                y = embedding_lookup(cat_lookup, x)
            else:
                y = cat_lookup
            ret_params = [tables, projs]

    y *= emb_scale
    return y, ret_params
def mask_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
                             params, tie_projs,
                             initializer=None, proj_initializer=None,
                             div_val=1, scope='adaptive_softmax',
                             proj_same_dim=True,
                             return_mean=True, **kwargs):
    """Adaptive softmax negative log-likelihood implemented with boolean masks.

    Args:
      hidden: 3-D hidden states (einsum indices ``ibd``).
      target: integer target ids with the same leading axes as ``hidden``.
      n_token: vocabulary size.
      d_embed: embedding width of the first bin; bin ``i`` uses
        ``d_embed // div_val ** i``.
      d_proj: width of ``hidden``.
      cutoffs: list of python ints; empty means a plain full softmax.
      params: ``[weights, projections]`` as returned by the embedding lookup.
      tie_projs: list of python bools, one per bin; when True the bin reuses
        the projection from ``params``.
      initializer: unused in this function.
      proj_initializer: initializer for untied projection variables.
      div_val: when 1 ``params[0]`` is one table sliced per bin; otherwise it
        is a list indexed by bin.
      scope: variable scope under which variables are created.
      proj_same_dim: see the untied-projection rule in the body.
      return_mean: when True the per-token loss is averaged to a scalar.
      **kwargs: accepted and ignored.

    Returns:
      Per-token negative log-likelihood, or its mean if ``return_mean``.
    """
    def _logit(x, W, b, proj):
        projected = x if proj is None else tf.einsum('ibd,ed->ibe', x, proj)
        return tf.einsum('ibd,nd->ibn', projected, W) + b

    params_W, params_projs = params[0], params[1]

    def _gather_logprob(logprob, labels):
        # Pick logprob[row, labels[row]] for every row.
        rows = tf.range(tf.shape(logprob)[0])
        return tf.gather_nd(logprob, tf.stack([rows, labels], 1))

    with tf.variable_scope(scope):
        if len(cutoffs) == 0:
            softmax_b = tf.get_variable(
                'bias', [n_token], initializer=tf.zeros_initializer())
            logits = _logit(hidden, params_W, softmax_b, params_projs)
            nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=target, logits=logits)
        else:
            cutoff_ends = [0] + cutoffs + [n_token]
            nll = tf.zeros_like(target, dtype=tf.float32)
            bin_bounds = zip(cutoff_ends[:-1], cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bin_bounds):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    in_bin = (target >= l_idx) & (target < r_idx)
                    positions = tf.where(in_bin)
                    bin_target = tf.boolean_mask(target, in_bin) - l_idx
                    bin_width = d_embed // (div_val ** i)

                    cur_W = params_W[l_idx: r_idx] if div_val == 1 else params_W[i]
                    cur_b = tf.get_variable(
                        'b', [r_idx - l_idx],
                        initializer=tf.zeros_initializer())

                    if tie_projs[i]:
                        cur_proj = params_projs if div_val == 1 else params_projs[i]
                    elif (div_val == 1 or not proj_same_dim) and d_proj == bin_width:
                        cur_proj = None
                    else:
                        cur_proj = tf.get_variable(
                            'proj', [bin_width, d_proj],
                            initializer=proj_initializer)

                    if i == 0:
                        # The head scores the first bin's tokens plus one
                        # extra "cluster" logit per tail bin.
                        cluster_W = tf.get_variable(
                            'cluster_W', [len(cutoffs), d_embed],
                            initializer=tf.zeros_initializer())
                        cluster_b = tf.get_variable(
                            'cluster_b', [len(cutoffs)],
                            initializer=tf.zeros_initializer())
                        cur_W = tf.concat([cur_W, cluster_W], 0)
                        cur_b = tf.concat([cur_b, cluster_b], 0)

                        head_logit = _logit(hidden, cur_W, cur_b, cur_proj)
                        # Reused by every tail bin in later iterations.
                        head_logprob = tf.nn.log_softmax(head_logit)
                        bin_head_logprob = tf.boolean_mask(head_logprob, in_bin)
                        bin_logprob = _gather_logprob(
                            bin_head_logprob, bin_target)
                    else:
                        bin_head_logprob = tf.boolean_mask(head_logprob, in_bin)
                        bin_hidden = tf.boolean_mask(hidden, in_bin)
                        # _logit expects rank 3, so add and then remove a
                        # leading axis.
                        tail_logit = tf.squeeze(
                            _logit(bin_hidden[None], cur_W, cur_b, cur_proj), 0)
                        tail_logprob = tf.nn.log_softmax(tail_logit)
                        # log p(cluster) from the head + log p(token | cluster).
                        cluster_col = cutoff_ends[1] + i - 1
                        bin_logprob = (
                            bin_head_logprob[:, cluster_col] +
                            _gather_logprob(tail_logprob, bin_target))

                    nll = nll + tf.scatter_nd(
                        positions, -bin_logprob, tf.to_int64(tf.shape(nll)))

    if return_mean:
        nll = tf.reduce_mean(nll)
    return nll
def mul_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
                            params, tie_projs,
                            initializer=None, proj_initializer=None,
                            div_val=1, perms=None, proj_same_dim=True,
                            scope='adaptive_softmax',
                            **kwargs):
    """Adaptive softmax mean loss implemented with matrix multiplies.

    Args:
      hidden: 3-D hidden states (einsum indices ``ibd``).
      target: integer target ids, indexed as [len, bsz].
      n_token: vocabulary size.
      d_embed: embedding width of the first bin; bin ``i`` uses
        ``d_embed // div_val ** i``.
      d_proj: width of ``hidden``.
      cutoffs: list of python ints; empty means a plain full softmax.
      params: ``[weights, projections]`` as returned by the embedding lookup.
      tie_projs: list of python bools, one per bin; when True the bin reuses
        the projection from ``params``.
      initializer: unused in this function.
      proj_initializer: initializer for untied projection variables.
      div_val: when 1 ``params[0]`` is one table sliced per bin; otherwise it
        is a list indexed by bin.
      perms: per-bin tensors, required when ``cutoffs`` is non-empty.
        ``perms[0]`` is multiplied with the [len, bsz] head loss;
        ``perms[i]`` for ``i > 0`` is indexed as [len, bsz, bin_size].
      proj_same_dim: see the untied-projection rule in the body.
      scope: variable scope under which variables are created.
      **kwargs: ``head_target`` is read from here and used as the labels of
        the head softmax in the adaptive branch.

    Returns:
      Scalar mean negative log-likelihood.
    """
    def _logit(x, W, b, proj):
        # Supports rank-3 [i, b, d] and rank-2 [i, d] inputs.
        if x.shape.ndims == 3:
            project_eq, score_eq = 'ibd,ed->ibe', 'ibd,nd->ibn'
        else:
            project_eq, score_eq = 'id,ed->ie', 'id,nd->in'
        projected = x if proj is None else tf.einsum(project_eq, x, proj)
        return tf.einsum(score_eq, projected, W) + b

    params_W, params_projs = params[0], params[1]

    with tf.variable_scope(scope):
        if len(cutoffs) == 0:
            softmax_b = tf.get_variable(
                'bias', [n_token], initializer=tf.zeros_initializer())
            logits = _logit(hidden, params_W, softmax_b, params_projs)
            nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=target, logits=logits)
            nll = tf.reduce_mean(nll)
        else:
            total_loss = 0
            total_cnt = 0
            cutoff_ends = [0] + cutoffs + [n_token]
            bin_bounds = zip(cutoff_ends[:-1], cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bin_bounds):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    bin_width = d_embed // (div_val ** i)

                    cur_W = params_W[l_idx: r_idx] if div_val == 1 else params_W[i]
                    cur_b = tf.get_variable(
                        'b', [r_idx - l_idx],
                        initializer=tf.zeros_initializer())

                    if tie_projs[i]:
                        cur_proj = params_projs if div_val == 1 else params_projs[i]
                    elif (div_val == 1 or not proj_same_dim) and d_proj == bin_width:
                        cur_proj = None
                    else:
                        cur_proj = tf.get_variable(
                            'proj', [bin_width, d_proj],
                            initializer=proj_initializer)

                    if i == 0:
                        # The head scores the first bin's tokens plus one
                        # extra "cluster" logit per tail bin.
                        cluster_W = tf.get_variable(
                            'cluster_W', [len(cutoffs), d_embed],
                            initializer=tf.zeros_initializer())
                        cluster_b = tf.get_variable(
                            'cluster_b', [len(cutoffs)],
                            initializer=tf.zeros_initializer())
                        cur_W = tf.concat([cur_W, cluster_W], 0)
                        cur_b = tf.concat([cur_b, cluster_b], 0)

                        head_logit = _logit(hidden, cur_W, cur_b, cur_proj)
                        head_target = kwargs.get("head_target")
                        # Reused by every tail bin in later iterations.
                        head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=head_target,
                            logits=head_logit)

                        masked_loss = head_nll * perms[i]
                        total_loss += tf.reduce_sum(masked_loss)
                        total_cnt += tf.reduce_sum(perms[i])
                    else:
                        # Gather head loss, hidden states and bin-relative
                        # targets into the k slots of this bin.
                        bin_head_nll = tf.einsum(
                            'ib,ibk->k', head_nll, perms[i])
                        bin_hidden = tf.einsum(
                            'ibd,ibk->kd', hidden, perms[i])
                        tail_logit = _logit(bin_hidden, cur_W, cur_b, cur_proj)

                        tail_target = tf.einsum(
                            'ib,ibk->k', tf.to_float(target - l_idx), perms[i])
                        tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=tf.to_int32(tail_target),
                            logits=tail_logit)

                        slot_nll = bin_head_nll + tail_nll
                        # Per-slot sum of perms[i] over [len, bsz]; used to
                        # weight each slot's loss and as its count.
                        slot_mask = tf.reduce_sum(perms[i], [0, 1])

                        masked_loss = slot_nll * slot_mask
                        total_loss += tf.reduce_sum(masked_loss)
                        total_cnt += tf.reduce_sum(slot_mask)

            nll = total_loss / total_cnt

    return nll
def _create_mask(qlen, mlen, same_length=False):
    """Build the [qlen, mlen + qlen] attention mask (1 = masked).

    Args:
      qlen: query length.
      mlen: memory length; the first ``mlen`` columns start out unmasked.
      same_length: when True, the first ``qlen`` columns additionally get the
        strict lower triangle of a [qlen, qlen] matrix added to them.

    Returns:
      Float tensor of shape [qlen, mlen + qlen].
    """
    ones = tf.ones([qlen, qlen])
    upper_with_diag = tf.matrix_band_part(ones, 0, -1)
    diag = tf.matrix_band_part(ones, 0, 0)
    memory_cols = tf.zeros([qlen, mlen])

    # Strict upper triangle: a query may not attend to later positions.
    mask = tf.concat([memory_cols, upper_with_diag - diag], 1)
    if not same_length:
        return mask

    lower_with_diag = tf.matrix_band_part(ones, -1, 0)
    left = mask[:, :qlen] + lower_with_diag - diag
    right = mask[:, qlen:]
    return tf.concat([left, right], 1)
def _cache_mem(curr_out, prev_mem, mem_len=None):
if mem_len is None or prev_mem is None:
new_mem = curr_out
elif mem_len == 0:
return prev_mem
else:
new_mem = tf.concat([prev_mem, curr_out], 0)[- mem_len:]
return tf.stop_gradient(new_mem)
def transformer(dec_inp, target, mems, n_token, n_layer, d_model, d_embed,
                n_head, d_head, d_inner, dropout, dropatt,
                initializer, is_training, proj_initializer=None,
                mem_len=None, cutoffs=[], div_val=1, tie_projs=[],
                same_length=False, clamp_len=-1, use_tpu=True,
                input_perms=None, target_perms=None, head_target=None,
                untie_r=False, proj_same_dim=True,
                scope='transformer'):
    """
    Build the full model graph and return its loss and updated memories.

    dec_inp: input token ids; axis 0 is the query length.
    target: target token ids passed to the softmax loss.
    mems: per-layer list of memory tensors, or None for no memory.
    mem_len: number of steps kept by `_cache_mem` for each layer.
    cutoffs: a list of python int. Cutoffs for adaptive softmax.
    tie_projs: a list of python bools. Whether to tie the projections.
    same_length: forwarded to `_create_mask`.
    clamp_len: if > 0, relative positions are clamped to this maximum.
    use_tpu: if True, use one_hot in embedding lookup and bin-based
             implementation of adaptive softmax.
    input_perms / target_perms: lists of tensors. Each tensor should be of
             size [len, bsz, bin_size]. Only used in the adaptive setting.
    head_target: labels for the head softmax of the bin-based loss.
    untie_r: if True, each layer has its own relative attention biases.

    Returns (loss, new_mems), where new_mems has one entry per layer.
    """
    new_mems = []
    with tf.variable_scope(scope):
        # One bias pair per layer when untied, otherwise shared by all layers.
        bias_shape = [n_layer, n_head, d_head] if untie_r else [n_head, d_head]
        r_w_bias = tf.get_variable('r_w_bias', bias_shape,
                                   initializer=initializer)
        r_r_bias = tf.get_variable('r_r_bias', bias_shape,
                                   initializer=initializer)

        qlen = tf.shape(dec_inp)[0]
        mlen = 0 if mems is None else tf.shape(mems[0])[0]
        klen = mlen + qlen

        if proj_initializer is None:
            proj_initializer = initializer

        if use_tpu:
            lookup_fn = mul_adaptive_embedding_lookup
        else:
            lookup_fn = mask_adaptive_embedding_lookup
        embeddings, shared_params = lookup_fn(
            x=dec_inp,
            n_token=n_token,
            d_embed=d_embed,
            d_proj=d_model,
            cutoffs=cutoffs,
            initializer=initializer,
            proj_initializer=proj_initializer,
            div_val=div_val,
            perms=input_perms,
            proj_same_dim=proj_same_dim)

        attn_mask = _create_mask(qlen, mlen, same_length)

        # Relative positions run from klen - 1 down to 0.
        pos_seq = tf.range(klen - 1, -1, -1.0)
        if clamp_len > 0:
            pos_seq = tf.minimum(pos_seq, clamp_len)
        inv_freq = 1 / (10000 ** (tf.range(0, d_model, 2.0) / d_model))
        pos_emb = positional_embedding(pos_seq, inv_freq)

        output = tf.layers.dropout(embeddings, dropout, training=is_training)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        layer_mems = [None] * n_layer if mems is None else mems

        for layer_idx in range(n_layer):
            layer_mem = layer_mems[layer_idx]
            # Cache the input of this layer as the next segment's memory.
            new_mems.append(_cache_mem(output, layer_mem, mem_len))

            with tf.variable_scope('layer_{}'.format(layer_idx)):
                output = rel_multihead_attn(
                    w=output,
                    r=pos_emb,
                    r_w_bias=r_w_bias[layer_idx] if untie_r else r_w_bias,
                    r_r_bias=r_r_bias[layer_idx] if untie_r else r_r_bias,
                    attn_mask=attn_mask,
                    mems=layer_mem,
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer)
                output = positionwise_FF(
                    inp=output,
                    d_model=d_model,
                    d_inner=d_inner,
                    dropout=dropout,
                    kernel_initializer=initializer,
                    is_training=is_training)

        output = tf.layers.dropout(output, dropout, training=is_training)

        if use_tpu:
            logsoftmax_fn = mul_adaptive_logsoftmax
        else:
            logsoftmax_fn = mask_adaptive_logsoftmax
        loss = logsoftmax_fn(
            hidden=output,
            target=target,
            n_token=n_token,
            d_embed=d_embed,
            d_proj=d_model,
            cutoffs=cutoffs,
            params=shared_params,
            tie_projs=tie_projs,
            initializer=initializer,
            proj_initializer=proj_initializer,
            div_val=div_val,
            perms=target_perms,
            head_target=head_target,
            proj_same_dim=proj_same_dim)
        return loss, new_mems