chore(transformer-xl): Initial commit

This commit is contained in:
Tibo De Peuter 2025-11-07 12:58:13 +01:00
parent ef4684ef39
commit 10512876f2
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
46 changed files with 10547 additions and 0 deletions

View file

@ -0,0 +1,65 @@
import os
import tensorflow as tf
def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
    """Build a device-chooser function that pins variables to ``ps_dev``.

    Args:
        gpu: Index formatted into the ``"/gpu:%d"`` device string returned
            for every op that is not a variable.
        ps_dev: Device string returned for ops whose type is ``"Variable"``.

    Returns:
        A callable taking an op (or a ``tf.NodeDef``) and returning the
        device string that op should be placed on.
    """
    def _assign(op):
        # Accept either a raw NodeDef or an object exposing ``.node_def``.
        node_def = op.node_def if not isinstance(op, tf.NodeDef) else op
        # NOTE(review): only the exact op type "Variable" is routed to
        # ps_dev; other variable op types (e.g. "VariableV2") would fall
        # through to the GPU -- confirm this is intended.
        return ps_dev if node_def.op == "Variable" else "/gpu:%d" % gpu

    return _assign
def average_grads_and_vars(tower_grads_and_vars):
    """Average the gradients computed by several towers.

    Args:
        tower_grads_and_vars: One sequence of ``(gradient, variable)`` pairs
            per tower. The sequences are zipped together, so the i-th pair
            of every tower is treated as belonging to the same variable.

    Returns:
        A list of ``(averaged_gradient, variable)`` tuples. The variable is
        taken from the first tower; the gradient is ``None`` when the first
        tower's gradient is ``None``.
    """
    def _mean_dense(pairs):
        # Accumulate every tower's gradient, then divide by the tower count.
        total = pairs[0][0]
        if len(pairs) == 1:
            return total
        for tower_grad, _ in pairs[1:]:
            total += tower_grad
        return total / len(pairs)

    def _mean_sparse(pairs):
        # Concatenate the slices of all towers and scale the values by the
        # tower count; the dense shape is taken from the first tower.
        first = pairs[0][0]
        if len(pairs) == 1:
            return first
        all_indices = tf.concat([g.indices for g, _ in pairs], 0)
        all_values = tf.concat([g.values for g, _ in pairs], 0) / len(pairs)
        return tf.IndexedSlices(all_values, all_indices, first.dense_shape)

    averaged = []
    for pairs in zip(*tower_grads_and_vars):
        first_grad = pairs[0][0]
        if first_grad is None:
            mean_grad = None
        elif isinstance(first_grad, tf.IndexedSlices):
            mean_grad = _mean_sparse(pairs)
        else:
            mean_grad = _mean_dense(pairs)
        # The variables are shared across towers, so they are redundant:
        # the first tower's reference to the variable is the one returned.
        averaged.append((mean_grad, pairs[0][1]))
    return averaged
def load_from_checkpoint(saver, logdir):
    """Restore the default session from the checkpoint recorded in ``logdir``.

    Args:
        saver: Object whose ``restore(session, path)`` method is called to
            load the checkpoint.
        logdir: Directory passed to ``tf.train.get_checkpoint_state``; also
            used as the base for a relative checkpoint path.

    Returns:
        ``True`` if a checkpoint path was found and ``saver.restore`` was
        called, ``False`` otherwise.
    """
    session = tf.get_default_session()
    state = tf.train.get_checkpoint_state(logdir)
    if not (state and state.model_checkpoint_path):
        return False

    checkpoint_path = state.model_checkpoint_path
    if not os.path.isabs(checkpoint_path):
        # A relative checkpoint path is resolved against logdir.
        checkpoint_path = os.path.join(logdir, checkpoint_path)
    saver.restore(session, checkpoint_path)
    return True