65 lines
2.2 KiB
Python
65 lines
2.2 KiB
Python
import os
|
|
import tensorflow as tf
|
|
|
|
def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
    """Build a device function that pins variable ops to ``ps_dev``.

    The returned callable is meant to be passed to ``tf.device(...)``; it maps
    each op to a device string: variable-creating ops go to ``ps_dev`` and
    every other op goes to ``/gpu:<gpu>``.

    Args:
        gpu: Index of the GPU that non-variable ops are placed on.
        ps_dev: Device string for variable-creating ops.

    Returns:
        A function taking an op (an object exposing ``node_def``, such as a
        ``tf.Operation``, or a ``NodeDef`` itself) and returning a device
        string.
    """
    # Op types that create variable storage. Checking only the legacy
    # "Variable" op would miss tf.Variable's "VariableV2" (TF >= 1.0) and
    # resource variables' "VarHandleOp"; TensorFlow's replica_device_setter
    # treats all three as parameter-server ops as well.
    variable_ops = frozenset(("Variable", "VariableV2", "VarHandleOp"))
    gpu_dev = "/gpu:%d" % gpu

    def _assign(op):
        # Operations expose their NodeDef via ``node_def``; a NodeDef passed
        # directly has no such attribute and is used as-is.
        node_def = getattr(op, "node_def", op)
        if node_def.op in variable_ops:
            return ps_dev
        return gpu_dev

    return _assign
|
|
|
|
|
|
def average_grads_and_vars(tower_grads_and_vars):
    """Average per-tower gradients position by position.

    Args:
        tower_grads_and_vars: One sequence of ``(gradient, variable)`` pairs
            per tower. Pairs are matched positionally with ``zip``, so every
            tower is assumed to list the same variables in the same order.

    Returns:
        A list of ``(averaged_gradient, variable)`` pairs, one per position.
        The gradient is ``None`` when the first tower's gradient is ``None``;
        a ``tf.IndexedSlices`` when every tower's gradient is one; otherwise
        a dense tensor. The variable is taken from the first tower.
    """

    def average_dense(grads):
        # A single tower has nothing to average.
        if len(grads) == 1:
            return grads[0]
        # convert_to_tensor densifies any IndexedSlices mixed in with dense
        # gradients; add_n then sums all towers in one op.
        return tf.add_n([tf.convert_to_tensor(g) for g in grads]) / len(grads)

    def average_sparse(grads):
        # A single tower has nothing to average.
        if len(grads) == 1:
            return grads[0]
        # Stay sparse: concatenate every tower's rows and scale the values by
        # 1/num_towers. Duplicate indices are left in place; this relies on
        # consumers summing them (as IndexedSlices densification does).
        indices = tf.concat([g.indices for g in grads], 0)
        values = tf.concat([g.values for g in grads], 0) / len(grads)
        return tf.IndexedSlices(values, indices, grads[0].dense_shape)

    averaged = []
    for grad_and_vars in zip(*tower_grads_and_vars):
        grads = [g for g, _ in grad_and_vars]
        if grads[0] is None:
            grad = None
        elif all(isinstance(g, tf.IndexedSlices) for g in grads):
            # Only take the sparse path when every tower is sparse; a dense
            # gradient has no .indices/.values to concatenate.
            grad = average_sparse(grads)
        else:
            grad = average_dense(grads)
        # Variables are shared across towers, so the first tower's reference
        # stands for all of them.
        averaged.append((grad, grad_and_vars[0][1]))
    return averaged
|
|
|
|
|
|
def load_from_checkpoint(saver, logdir, sess=None):
    """Restore the most recent checkpoint recorded in ``logdir``, if any.

    Args:
        saver: Object with a ``restore(sess, path)`` method, typically a
            ``tf.train.Saver``.
        logdir: Directory whose checkpoint state is read with
            ``tf.train.get_checkpoint_state``.
        sess: Session to restore into. When omitted, the current default
            session (``tf.get_default_session()``) is used.

    Returns:
        True if a checkpoint path was found and ``saver.restore`` was called;
        False if ``logdir`` has no checkpoint state or no recorded path.
    """
    if sess is None:
        sess = tf.get_default_session()

    ckpt = tf.train.get_checkpoint_state(logdir)
    if not (ckpt and ckpt.model_checkpoint_path):
        return False

    ckpt_path = ckpt.model_checkpoint_path
    # The recorded path may be relative; relative paths are resolved
    # against logdir, absolute ones are used unchanged.
    if not os.path.isabs(ckpt_path):
        ckpt_path = os.path.join(logdir, ckpt_path)
    saver.restore(sess, ckpt_path)
    return True
|