chore(transformer-xl): Initial commit

parent ef4684ef39 · commit 10512876f2
46 changed files with 10547 additions and 0 deletions
transformer-xl/tf/gpu_utils.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import os
import tensorflow as tf


def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
    """Returns a device function: ops of type "Variable" are placed on
    `ps_dev`, every other op on the given GPU."""
    def _assign(op):
        node_def = op if isinstance(op, tf.NodeDef) else op.node_def
        if node_def.op == "Variable":
            return ps_dev
        else:
            return "/gpu:%d" % gpu
    return _assign
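The closure returned by assign_to_gpu can be passed straight to tf.device, which in TF1 accepts either a device string or a function mapping each op to a device string. A minimal sketch of that usage, not part of this commit; the tensors are illustrative:

    import tensorflow as tf

    # assign_to_gpu is the function defined above. Every op built in this
    # block that is not of type "Variable" resolves to /gpu:1; a "Variable"
    # op would resolve to the CPU parameter-server device instead.
    with tf.device(assign_to_gpu(gpu=1, ps_dev="/device:CPU:0")):
        a = tf.ones([4, 4])    # placed on /gpu:1
        b = tf.matmul(a, a)    # placed on /gpu:1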
def average_grads_and_vars(tower_grads_and_vars):
    """Averages the per-tower (gradient, variable) lists into a single
    list suitable for `Optimizer.apply_gradients`."""

    def average_dense(grad_and_vars):
        if len(grad_and_vars) == 1:
            return grad_and_vars[0][0]

        grad = grad_and_vars[0][0]
        for g, _ in grad_and_vars[1:]:
            grad += g
        return grad / len(grad_and_vars)

    def average_sparse(grad_and_vars):
        # IndexedSlices cannot simply be summed; concatenate the indices
        # and the (pre-divided) values from all towers instead.
        if len(grad_and_vars) == 1:
            return grad_and_vars[0][0]

        indices = []
        values = []
        for g, _ in grad_and_vars:
            indices += [g.indices]
            values += [g.values]
        indices = tf.concat(indices, 0)
        values = tf.concat(values, 0) / len(grad_and_vars)
        return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape)

    average_grads_and_vars = []
    for grad_and_vars in zip(*tower_grads_and_vars):
        if grad_and_vars[0][0] is None:
            grad = None
        elif isinstance(grad_and_vars[0][0], tf.IndexedSlices):
            grad = average_sparse(grad_and_vars)
        else:
            grad = average_dense(grad_and_vars)
        # Keep in mind that the Variables are redundant because they are
        # shared across towers, so we just return the first tower's pointer
        # to the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads_and_vars.append(grad_and_var)
    return average_grads_and_vars
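A hedged sketch of the intended call pattern: each tower computes its own (grad, var) list with a shared optimizer, the lists are averaged, and one apply_gradients step updates the shared variables. Here num_gpus and build_tower_loss are hypothetical stand-ins for the training script's own code, and variable sharing across towers is assumed to be handled by the caller's variable scopes:

    import tensorflow as tf

    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)

    # One (grad, var) list per tower, all from the same optimizer.
    tower_grads_and_vars = []
    for i in range(num_gpus):                      # num_gpus: assumed given
        with tf.device(assign_to_gpu(gpu=i)):
            loss = build_tower_loss(i)             # hypothetical loss builder
            tower_grads_and_vars.append(optimizer.compute_gradients(loss))

    # Average across towers, then apply a single synchronized update.
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    train_op = optimizer.apply_gradients(grads_and_vars)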
def load_from_checkpoint(saver, logdir):
    """Restores the default session from the latest checkpoint in `logdir`.
    Returns True on success, False if no checkpoint was found."""
    sess = tf.get_default_session()
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        if os.path.isabs(ckpt.model_checkpoint_path):
            # Restores from a checkpoint with an absolute path.
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Restores from a checkpoint with a relative path.
            saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path))
        return True
    return False
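Note that load_from_checkpoint reads the session via tf.get_default_session(), so it must be called inside a `with tf.Session()` block (or after sess.as_default()). A sketch of resuming training under that assumption; the logdir name is made up for illustration:

    import tensorflow as tf

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # "EXP-enwik8" is a hypothetical logdir holding Saver checkpoints.
        if load_from_checkpoint(saver, "EXP-enwik8"):
            print("Resumed from the latest checkpoint.")
        else:
            print("No checkpoint found; starting from scratch.")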