# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # =================================================================== """TPUEstimator class.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import copy import os import signal import sys import threading import time import numpy as np import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin import math try: import google3 from google3.third_party.tensorflow.contrib.tpu.python.ops import tpu_ops from google3.third_party.tensorflow.contrib.tpu.python.tpu import error_handling from google3.third_party.tensorflow.contrib.tpu.python.tpu import session_support from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_config from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_context from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_feed from google3.third_party.tensorflow.contrib.tpu.python.tpu import training_loop from google3.third_party.tensorflow.contrib.tpu.python.tpu import util as util_lib from google3.third_party.tensorflow.contrib.training.python.training import hparam from google3.third_party.tensorflow.core.framework import variable_pb2 from google3.third_party.tensorflow.core.framework.summary_pb2 import Summary from google3.third_party.tensorflow.core.protobuf import config_pb2 from google3.third_party.tensorflow.python.data.ops import dataset_ops from google3.third_party.tensorflow.python.data.util import nest as data_nest from google3.third_party.tensorflow.python.estimator import estimator as estimator_lib from google3.third_party.tensorflow.python.estimator import model_fn as model_fn_lib from google3.third_party.tensorflow.python.estimator.export import export_output as export_output_lib from google3.third_party.tensorflow.python.framework import constant_op from google3.third_party.tensorflow.python.framework import dtypes from google3.third_party.tensorflow.python.framework import errors from google3.third_party.tensorflow.python.framework import ops from google3.third_party.tensorflow.python.ops import array_ops from google3.third_party.tensorflow.python.ops import check_ops from google3.third_party.tensorflow.python.ops import control_flow_ops from google3.third_party.tensorflow.python.ops import init_ops from google3.third_party.tensorflow.python.ops import math_ops from google3.third_party.tensorflow.python.ops import resource_variable_ops from google3.third_party.tensorflow.python.ops import state_ops from google3.third_party.tensorflow.python.ops import summary_ops_v2 as contrib_summary from google3.third_party.tensorflow.python.ops import variable_scope from google3.third_party.tensorflow.python.ops import variables from google3.third_party.tensorflow.python.platform import tf_logging as logging from google3.third_party.tensorflow.python.saved_model import tag_constants from google3.third_party.tensorflow.python.summary import summary from google3.third_party.tensorflow.python.training import basic_session_run_hooks from google3.third_party.tensorflow.python.training import evaluation from google3.third_party.tensorflow.python.training import session_run_hook from google3.third_party.tensorflow.python.training import training from google3.third_party.tensorflow.python.training import training_util from google3.third_party.tensorflow.python.util import function_utils from google3.third_party.tensorflow.python.util import nest from google3.third_party.tensorflow.python.util import tf_inspect except: import tensorflow from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import error_handling from tensorflow.contrib.tpu.python.tpu import session_support from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_context from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.contrib.training.python.training import hparam from tensorflow.core.framework import variable_pb2 from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest as data_nest from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import summary_ops_v2 as contrib_summary from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import evaluation from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.training import training_util from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. _TPU_ESTIMATOR = 'custom_tpu_estimator' # CHANGE FOR RECURRENCY _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CTX_KEY = 'context' _USE_TPU_KEY = 'use_tpu' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' # Ideally _USE_TPU_KEY should be reserved as well. However there are already # models that make use of this key, thus it can not be reserved now to prevent # breakage. In the long run, we would like to mitigate this by migrating models # off of using _USE_TPU_KEY. _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] # TODO(b/65703635): Flip the value and remove all dead code. Currently, this is # only used for per-core based deployments. For per-host based pipelines, if a # user returns a Dataset instance it will be automatically wrapped in a # tf.while_loop (This can be disabled by returning features and labels # explicitly). _WRAP_INPUT_FN_INTO_WHILE_LOOP = False ops.register_proto_function( '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR), proto_type=variable_pb2.VariableDef, to_proto=resource_variable_ops._to_proto_fn, # pylint: disable=protected-access from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access def _create_global_step(graph): graph = graph or ops.get_default_graph() if training.get_global_step(graph) is not None: raise ValueError('"global_step" already exists.') # Create in proper graph and base name_scope. with graph.as_default() as g, g.name_scope(None): return variable_scope.get_variable( ops.GraphKeys.GLOBAL_STEP, shape=[], dtype=dtypes.int64, initializer=init_ops.zeros_initializer(), trainable=False, use_resource=True, collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) def _create_or_get_iterations_per_loop(): """Creates or gets the iterations_per_loop variable. In TPUEstimator, the user provided computation, the model_fn, is wrapped inside a tf.while_loop for peak performance. The iterations of the loop are specified by this variable, which adjusts its value on the CPU after each TPU program execution and before the next TPU execution. The purpose of using a variable, rather then a constant, is to allow TPUEstimator adapt the TPU training iterations according to the final steps specified by users. For example, if the user sets the iterations_per_loop as 4 in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop variable will have the following value before each TPU training. - 1-th TPU execution: iterations_per_loop = 4 - 2-th TPU execution: iterations_per_loop = 4 - 3-th TPU execution: iterations_per_loop = 2 As model_fn increases the global step once per train_op invocation, the global step is 10 after all TPU executions, matching the steps=10 inputs passed in by users. Returns: A TF non-trainable resource variable. Raises: RuntimeError: If multi iterations_per_loop variables were found. """ graph = ops.get_default_graph() collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) iter_vars = graph.get_collection(collection_name) if len(iter_vars) == 1: return iter_vars[0] elif len(iter_vars) > 1: raise RuntimeError('Multiple iterations_per_loop_var in collection.') with ops.colocate_with(training_util.get_global_step()): with variable_scope.variable_scope( _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE): return variable_scope.get_variable( _ITERATIONS_PER_LOOP_VAR, initializer=init_ops.zeros_initializer(), shape=[], dtype=dtypes.int32, trainable=False, collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES], use_resource=True) def _sync_variables_ops(): # Gets the variables back from TPU nodes. This means the variables updated # by TPU will now be *synced* to host memory. return [ array_ops.check_numerics(v.read_value(), 'Gradient for %s is NaN' % v.name).op for v in variables.trainable_variables() ] def _increase_eval_step_op(iterations_per_loop): """Returns an op to increase the eval step for TPU evaluation. Args: iterations_per_loop: Tensor. The number of eval steps running in TPU system before returning to CPU host for each `Session.run`. Returns: An operation """ eval_step = evaluation._get_or_create_eval_step() # pylint: disable=protected-access # Estimator evaluate increases 1 by default. So, we increase the difference. return state_ops.assign_add( eval_step, math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype), use_locking=True) def _extract_key_names(tensor_or_dict): if isinstance(tensor_or_dict, dict): return sorted(tensor_or_dict.keys()) return [] class _SIGNAL(object): """Signal used to control the thread of infeed/outfeed. All preserved signals must be negative numbers. Positive numbers are used to indicate the number of iterations for next training/evaluation loop. """ NEXT_BATCH = -1 STOP = -2 class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and `export_outputs`. For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where `metric_fn` runs on CPU to generate metrics and `tensors` represents the `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. To be precise, TPU evaluation expects a slightly different signature from the @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The `tensors` usually specify the model logits, which are transferred back from TPU system to CPU host. All tensors must have be batch-major, i.e., the batch size is the first dimension. Once all tensors are available at CPU host from all shards, they are concatenated (on CPU) and passed as positional arguments to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is a dict. `metric_fn` takes the `tensors` and returns a dict from metric string name to the result of calling a metric function, namely a `(metric_tensor, update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the `eval_metrics`. `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This function should not capture any Tensors in `model_fn`. `host_call` is a tuple of a `function` and a list or dictionary of `tensors` to pass to that function and returns a list of Tensors. `host_call` currently works for train() and evaluate(). The Tensors returned by the function is executed on the CPU on every step, so there is communication overhead when sending tensors from TPU to CPU. To reduce the overhead, try reducing the size of the tensors. The `tensors` are concatenated along their major (batch) dimension, and so must be >= rank 1. The `host_call` is useful for writing summaries with @{tf.contrib.summary.create_file_writer}. """ def __new__(cls, mode, predictions=None, loss=None, train_op=None, eval_metrics=None, export_outputs=None, scaffold_fn=None, host_call=None, training_hooks=None, evaluation_hooks=None, prediction_hooks=None): """Creates a validated `TPUEstimatorSpec` instance.""" host_calls = {} if eval_metrics is not None: host_calls['eval_metrics'] = eval_metrics if host_call is not None: host_calls['host_call'] = host_call _OutfeedHostCall.validate(host_calls) training_hooks = list(training_hooks or []) evaluation_hooks = list(evaluation_hooks or []) prediction_hooks = list(prediction_hooks or []) for hook in training_hooks + evaluation_hooks + prediction_hooks: if not isinstance(hook, session_run_hook.SessionRunHook): raise TypeError( 'All hooks must be SessionRunHook instances, given: {}'.format( hook)) return super(TPUEstimatorSpec, cls).__new__( cls, mode=mode, predictions=predictions, loss=loss, train_op=train_op, eval_metrics=eval_metrics, export_outputs=export_outputs, scaffold_fn=scaffold_fn, host_call=host_call, training_hooks=training_hooks, evaluation_hooks=evaluation_hooks, prediction_hooks=prediction_hooks) def as_estimator_spec(self): """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" host_calls = {} if self.eval_metrics is not None: host_calls['eval_metrics'] = self.eval_metrics if self.host_call is not None: host_calls['host_call'] = self.host_call host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls) eval_metric_ops = None if self.eval_metrics is not None: eval_metric_ops = host_call_ret['eval_metrics'] hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] hooks = list(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, predictions=self.predictions, loss=self.loss, train_op=self.train_op, eval_metric_ops=eval_metric_ops, export_outputs=self.export_outputs, scaffold=scaffold, training_hooks=self.training_hooks + hooks, evaluation_hooks=self.evaluation_hooks + hooks, prediction_hooks=self.prediction_hooks + hooks) class _OpQueueContext(object): """Manages work queue and thread for a infeed/outfeed thread.""" def __init__(self, name, target, args): self._name = name self._queue = Queue.Queue() args = (self,) + args self._thread = threading.Thread(name=name, target=target, args=args) self._thread.daemon = True self._thread.start() def stop(self): self._queue.put(_SIGNAL.STOP) def send_next_batch_signal(self, iterations): self._queue.put(iterations) def read_iteration_counts(self): while True: iterations = self._queue.get(block=True) logging.debug('%s read iterations %s', self._name, iterations) if iterations == _SIGNAL.STOP: logging.info('%s received shutdown signal, stopping.', self._name) return yield iterations def join(self): logging.info('Shutting down %s thread.' % self._name) self.stop() self._thread.join() class _OpSignalOnceQueueContext(_OpQueueContext): """Manages work queue and thread for a infeed/outfeed thread. This subclass only signals once. """ def __init__(self, name, target, args): super(_OpSignalOnceQueueContext, self).__init__(name, target, args) self._has_signaled = False def send_next_batch_signal(self, iterations): if not self._has_signaled: self._queue.put(iterations) self._has_signaled = True class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): """A Session hook setting up the TPU initialization, infeed, and outfeed. This hook does two major things: 1. initialize and shutdown TPU system. 2. launch and join the threads for infeed enqueue and (optional) outfeed dequeue. """ def __init__(self, ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=True, rendezvous=None): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops self._rendezvous = rendezvous self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) self._feed_error = None self._finished = False def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() self._init_ops = [tpu.initialize_system(job=self._master_job)] self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() self._init_ops.extend(summary_writer_init_ops) # Get all the writer resources from the initializer, so we know what to # flush. for op in summary_writer_init_ops: self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) def _run_infeed(self, queue_ctx, session): logging.info('Starting infeed thread controller.') if self._initial_infeed_sleep_secs: logging.info('%s thread sleeping for %d seconds.', self._name, self._initial_infeed_sleep_secs) time.sleep(self._initial_infeed_sleep_secs) logging.info('%s thread starting after sleep', self._name) with self._rendezvous.catch_errors(source='infeed', session=session): if self._run_infeed_loop_on_coordinator: for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) session.run(self._enqueue_ops) else: for _ in queue_ctx.read_iteration_counts(): session.run(self._enqueue_ops) logging.info('Infeed thread finished, shutting down.') def _run_outfeed(self, queue_ctx, session): logging.info('Starting outfeed thread controller.') with self._rendezvous.catch_errors(source='outfeed', session=session): for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) session.run(self._dequeue_ops) logging.info('Outfeed thread finished, shutting down.') def _create_infeed_controller(self, name, target, args): return _OpQueueContext(name=name, target=target, args=args) def after_create_session(self, session, coord): logging.info('Init TPU system') session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) self._outfeed_controller = _OpQueueContext( name='OutfeedController', target=self._run_outfeed, args=(session,)) def before_run(self, run_context): self._feed_error = None iterations = run_context.session.run(self._iterations_per_loop_var) logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) self._infeed_controller.send_next_batch_signal(iterations) logging.info('Dequeue next (%d) batch(es) of data from outfeed.', iterations) self._outfeed_controller.send_next_batch_signal(iterations) def end(self, session): self._finished = True logging.info('Stop infeed thread controller') self._infeed_controller.join() self._rendezvous.record_done('infeed') logging.info('Stop output thread controller') self._outfeed_controller.join() self._rendezvous.record_done('outfeed') logging.info('Shutdown TPU system.') session.run(self._finalize_ops) class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False, rendezvous=rendezvous) def _create_infeed_controller(self, name, target, args): return _OpSignalOnceQueueContext(name=name, target=target, args=args) class _TPUStopAtStepHook(session_run_hook.SessionRunHook): """Hook that requests stop at a specified step. This hook is similar to the `session_run_hook._StopAfterNEvalsHook` with following differences for TPU training: 1. This hook sets the variable for iterations_per_loop, which is used by `TPUInfeedOutfeedSessionHook` to control the iterations for infeed/outfeed. As the hook execution order is not guaranteed, the variable update is handled in `after_create_session` and `after_run` as `TPUInfeedOutfeedSessionHook` reads the variable value in `before_run`. 2. For each training loop (session.run), the global step could be increased multiple times on TPU. The global step tensor value will be explicitly read again in `after_run` to ensure the latest value is retrieved to avoid race condition. """ def __init__(self, iterations, num_steps=None, last_step=None): """Initializes a `StopAtStepHook`. Args: iterations: The number of iterations to run optimizer per training loop. num_steps: Number of steps to execute. last_step: Step after which to stop. Raises: ValueError: If one of the arguments is invalid. """ if num_steps is None and last_step is None: raise ValueError('One of num_steps or last_step must be specified.') if num_steps is not None and last_step is not None: raise ValueError('Only one of num_steps or last_step can be specified.') self._num_steps = num_steps self._last_step = last_step self._iterations = iterations def _next_iterations(self, global_step, last_step): gap = last_step - global_step return min(gap, self._iterations) def begin(self): self._global_step_tensor = training_util.get_global_step() if self._global_step_tensor is None: raise RuntimeError('Global step should be created.') self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): global_step = session.run(self._global_step_tensor) if self._last_step is None: self._last_step = global_step + self._num_steps iterations = self._next_iterations(global_step, self._last_step) self._iterations_per_loop_var.load(iterations, session=session) def after_run(self, run_context, run_values): # Global step cannot be retrieved via SessionRunArgs and before_run due to # race condition. global_step = run_context.session.run(self._global_step_tensor) if global_step >= self._last_step: run_context.request_stop() else: iterations = self._next_iterations(global_step, self._last_step) self._iterations_per_loop_var.load( iterations, session=run_context.session) class _SetEvalIterationsHook(session_run_hook.SessionRunHook): """Hook that requests stop at a specified step.""" def __init__(self, num_steps): """Initializes a `_SetEvalIterationsHook`. Args: num_steps: Number of steps to execute. """ self._num_steps = num_steps def begin(self): self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): self._iterations_per_loop_var.load(self._num_steps, session=session) class _StoppingPredictHook(session_run_hook.SessionRunHook): """Hook that requests stop according to the stopping signal in prediction.""" def __init__(self, scalar_stopping_signal): self._scalar_stopping_signal = scalar_stopping_signal def begin(self): self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): # This is not necessary as we do not run infeed enqueue and outfeed dequeue # in side threads for prediction model. But it makes the # TPUInfeedOutfeedSessionHook prints nice message. self._iterations_per_loop_var.load(1, session=session) def before_run(self, run_context): return session_run_hook.SessionRunArgs(self._scalar_stopping_signal) def after_run(self, run_context, run_values): _ = run_context scalar_stopping_signal = run_values.results if _StopSignals.should_stop(scalar_stopping_signal): # NOTE(xiejw): In prediction, stopping signals are inserted for each # batch. And we append one more batch to signal the system it should stop. # The data flow might look like # # batch 0: images, labels, stop = 0 (user provided) # batch 1: images, labels, stop = 0 (user provided) # ... # batch 99: images, labels, stop = 0 (user provided) # batch 100: images, labels, stop = 1 (TPUEstimator appended) # # where the final batch (id = 100) is appended by TPUEstimator, so we # should drop it before returning the predictions to user. # To achieve that, we throw the OutOfRangeError in after_run. Once # Monitored Session sees this error in SessionRunHook.after_run, the # "current" prediction, i.e., batch with id=100, will be discarded # immediately raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') def generate_per_core_enqueue_ops_fn_for_host( ctx, input_fn, inputs_structure_recorder, host_device, host_id): """Generates infeed enqueue ops for per-core input_fn on a single host.""" captured_infeed_queue = _CapturedObject() tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """A fn returns enqueue_ops.""" num_cores_per_host = ctx.num_of_cores_per_host per_host_sharded_inputs = [] for core_ordinal in range(num_cores_per_host): with ops.name_scope('ordinal_%d' % (core_ordinal)): user_context = tpu_context.TPUContext( internal_ctx=ctx, input_device=host_device, invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal ) inputs = _Inputs.from_input_fn(input_fn(user_context)) if inputs.is_dataset: raise TypeError( '`input_fn` returning `Dataset` is not yet supported in ' 'per-Core input pipeline deployment yet. Please set ' 'TPUConfig.per_host_input_for_training to True or return ' '`features` and `labels` from `input_fn`') features, labels = inputs.features_and_labels() inputs_structure_recorder.validate_and_record_structure( features, labels) flattened_inputs = ( inputs_structure_recorder.flatten_features_and_labels( features, labels)) per_host_sharded_inputs.append(flattened_inputs) infeed_queue = tpu_feed.InfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0])) captured_infeed_queue.capture(infeed_queue) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) return per_host_enqueue_ops return enqueue_ops_fn, captured_infeed_queue def generate_per_host_enqueue_ops_fn_for_host( ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id): """Generates infeed enqueue ops for per-host input_fn on a single host.""" captured_infeed_queue = _CapturedObject() hooks = [] with ops.DEVICE(device): user_context = tpu_context.TPUContext( internal_ctx=ctx, input_device=device, invocation_index=host_id) inputs = _Inputs.from_input_fn(input_fn(user_context)) is_dataset = inputs.is_dataset if ctx.mode == model_fn_lib.ModeKeys.PREDICT: if not is_dataset: raise TypeError( 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' '`features` and `labels`.') if batch_axis is not None: raise TypeError('For mode PREDICT, batch_axis is not supported yet.') inputs = _InputsWithStoppingSignals( dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn, add_padding=True) if is_dataset: hooks.append(inputs.dataset_initializer_hook()) tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """A Fn returning the TPU infeed enqueue ops. By providing as a Fn, it can be invoked inside the tf.while_loop such that the input pipeline for multiple iterations can be executed by one Session.run call. Returns: list of dict of ops. """ with ops.DEVICE(device): num_of_replicas_per_host = ctx.num_of_replicas_per_host # Convert user input to features and labels. If the user returns a # dataset, it is initialized and the features and labels extracted via # `dataset.iterator.get_next()` features, labels = inputs.features_and_labels() signals = inputs.signals() inputs_structure_recorder.validate_and_record_structure(features, labels) unsharded_tensor_list = ( inputs_structure_recorder.flatten_features_and_labels( features, labels, signals)) infeed_queue = tpu_feed.InfeedQueue( tuple_types=[t.dtype for t in unsharded_tensor_list], tuple_shapes=[t.shape for t in unsharded_tensor_list], shard_dimensions=batch_axis) captured_infeed_queue.capture(infeed_queue) infeed_queue.set_number_of_shards(num_of_replicas_per_host) per_host_enqueue_ops = ( infeed_queue.split_inputs_and_generate_enqueue_ops( unsharded_tensor_list, placement_function=lambda x: device, tpu_ordinal_function=tpu_ordinal_function_impl)) if signals is None: return per_host_enqueue_ops else: return { 'ops': per_host_enqueue_ops, 'signals': signals, } return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset def generate_per_host_v2_enqueue_ops_fn_for_host( ctx, input_fn, inputs_structure_recorder, device, host_id): """Generates infeed enqueue ops for per-host input_fn on a single host.""" captured_infeed_queue = _CapturedObject() hooks = [] with ops.DEVICE(device): user_context = tpu_context.TPUContext( internal_ctx=ctx, input_device=device, invocation_index=host_id) inputs = _Inputs.from_input_fn(input_fn(user_context)) is_dataset = inputs.is_dataset if not is_dataset: raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 ' 'input pipeline configuration.') if ctx.mode == model_fn_lib.ModeKeys.PREDICT: inputs = _InputsWithStoppingSignals( dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn, add_padding=True, num_invocations_per_step=ctx.num_of_replicas_per_host) hooks.append(inputs.dataset_initializer_hook()) tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """Generates the per_host enqueue ops.""" control_deps = [] per_host_sharded_inputs = [] num_replicas_per_host = ctx.num_of_replicas_per_host cached_signals = None with ops.DEVICE(device): if not inputs.is_dataset: raise TypeError('`input_fn` must return a `Dataset` for this mode.') for _ in range(num_replicas_per_host): # Use control dependencies to ensure a deterministic ordering. with ops.control_dependencies(control_deps): features, labels = inputs.features_and_labels() # Calls get_next() signals = inputs.signals() # All the replicas share the replica 0's stopping singal. # This avoids inconsistent state among different model replcias. if cached_signals: signals['stopping'] = cached_signals['stopping'] else: cached_signals = signals inputs_structure_recorder.validate_and_record_structure( features, labels) flattened_inputs = ( inputs_structure_recorder.flatten_features_and_labels( features, labels, signals)) control_deps.extend(flattened_inputs) per_host_sharded_inputs.append(flattened_inputs) if inputs_structure_recorder.flattened_input_dims: input_partition_dims = inputs_structure_recorder.flattened_input_dims if signals: input_partition_dims += [None] * len(signals) # pylint: disable=protected-access infeed_queue = tpu_feed._PartitionedInfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0]), host_id=host_id, input_partition_dims=input_partition_dims, device_assignment=ctx.device_assignment) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs) else: infeed_queue = tpu_feed.InfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0])) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) captured_infeed_queue.capture(infeed_queue) if signals is None: return per_host_enqueue_ops else: return { 'ops': per_host_enqueue_ops, 'signals': signals, } return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, num_hosts): """Generates infeed enqueue ops for one input_fn on all the hosts.""" captured_infeed_queue = _CapturedObject() hooks = [] device_0 = ctx.tpu_host_placement_function(host_id=0) with ops.DEVICE(device_0): user_context = tpu_context.TPUContext( internal_ctx=ctx, input_device=device_0, invocation_index=0) inputs = _Inputs.from_input_fn(input_fn(user_context)) is_dataset = inputs.is_dataset if ctx.mode == model_fn_lib.ModeKeys.PREDICT: if not is_dataset: raise TypeError( 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' '`features` and `labels`.') inputs = _InputsWithStoppingSignals( dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn, add_padding=True) if is_dataset: hooks.append(inputs.dataset_initializer_hook()) num_replicas_per_host = ctx.num_of_replicas_per_host def tpu_ordinal_function_impl(replica_id): if ctx.device_assignment: return ctx.device_assignment.tpu_ordinal(replica=replica_id) else: return replica_id % num_replicas_per_host def device_function_impl(replica_id): return ctx.tpu_host_placement_function(replica_id=replica_id) def enqueue_ops_fn(): """Generates enqueue ops for all the hosts.""" broadcasted_inputs = [] flattened_inputs = None # Cache result from input_fn. signals = None for host_id in xrange(num_hosts): with ops.DEVICE(ctx.tpu_host_placement_function(host_id=host_id)): for _ in xrange(ctx.num_of_replicas_per_host): # Note: input_fn is only called once at host 0 for the first replica. # The features and labels returned from that invocation are # broadcasted to other replicas(including the replicas on other # hosts). if flattened_inputs is None: features, labels = inputs.features_and_labels() # Calls get_next() signals = inputs.signals() inputs_structure_recorder.validate_and_record_structure( features, labels) flattened_inputs = ( inputs_structure_recorder.flatten_features_and_labels( features, labels, signals)) broadcasted_inputs.append(flattened_inputs) infeed_queue = tpu_feed.InfeedQueue( number_of_tuple_elements=len(broadcasted_inputs[0])) captured_infeed_queue.capture(infeed_queue) enqueue_ops = infeed_queue.generate_enqueue_ops( broadcasted_inputs, tpu_ordinal_function=tpu_ordinal_function_impl, placement_function=device_function_impl) if signals is None: return enqueue_ops else: return { 'ops': enqueue_ops, 'signals': signals, } return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset class _InputPipeline(object): """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from call site. To be precise, based on the configuration in `_InternalTPUContext`, it invokes `input_fn` for all cores (usually multi-host TPU training) or for one host (usually for single-host TPU evaluation), and sends all `features` and `labels` returned by `input_fn` to TPU infeed. For per-core invocation, `features` and `labels` are piped to infeed directly, one tuple for each core. For per-host invocation, `features` and `labels` are split at host (with respect to `batch_axis`) and piped to all cores accordingly. In addition, flatten/unflatten are handled by `_InputPipeline` also. Model inputs returned by the `input_fn` can have one of the following forms: 1. features 2. (features, labels) 3. ((arbitrarily nested structure of features), labels) Internally, form 1 is reformed to `(features, None)` as features and labels are passed separately to underlying methods. For TPU training, TPUEstimator may expect multiple `features` and `labels` tuples one for each core. TPUEstimator allows various different structures for inputs (namely `features` and `labels`). `features` can be `Tensor`, dict of string name to `Tensor`, or nested tuples and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`. TPU infeed/outfeed library expects flattened tensor list. So, `features` and `labels` need to be flattened, before infeed enqueue, and the structure of them needs to be recorded, in order to restore them after infeed dequeue. """ class InputsStructureRecorder(object): """The recorder to record inputs structure.""" def __init__(self, input_partition_dims=None): # Holds the structure of inputs self._feature_structure = {} self._flattened_input_dims = None if input_partition_dims: # This should have been validated in TPUConfig. assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.' if len(input_partition_dims) == 2: self._feature_dims, self._label_dims = input_partition_dims else: self._feature_dims = input_partition_dims[0] self._label_dims = None assert self._feature_dims is not None, ('input_partition_dims[0] must ' 'not be None') else: self._feature_dims = None self._label_dims = None # Internal state. self._initialized = False @property def flattened_input_dims(self): assert self._initialized, 'InputsStructureRecorder is not initialized.' return self._flattened_input_dims def has_labels(self): return 'labels' in self._feature_structure def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, label_dims_names, label_names, has_labels): """Flatten input dims with the same order as flattened input tensors.""" flattened_input_dims = [] if feature_dims_names: # We need a fixed ordering for matching the tensors in features. flattened_input_dims.extend( [feature_dims[name] for name in feature_dims_names]) else: flattened_input_dims.append(feature_dims) if label_dims_names: # We need a fixed ordering for matching the tensors in labels. flattened_input_dims.extend( [label_dims[name] for name in label_dims_names]) else: if label_names: num_tensors_in_label = len(label_names) else: num_tensors_in_label = int(has_labels) # Setting `None` in input_partition_dims[1] will apply `None` to # all the tensors in labels, regardless of internal structure. flattened_input_dims.extend([label_dims] * num_tensors_in_label) return flattened_input_dims def validate_and_record_structure(self, features, labels): """Validates and records the structure of `features` and `labels`.""" # Extract structure. has_labels = labels is not None feature_names = _extract_key_names(features) label_names = _extract_key_names(labels) if not self._initialized: # Record structure. self._initialized = True if self._feature_dims is not None: feature_dims_names = _extract_key_names(self._feature_dims) if feature_dims_names != feature_names: raise ValueError( 'TPUConfig.input_partition_dims[0] mismatched feature' ' keys. Expected {}, got {}'.format(feature_names, feature_dims_names)) label_dims_names = _extract_key_names(self._label_dims) if self._label_dims is not None and label_dims_names != label_names: raise ValueError( 'TPUConfig.input_partition_dims[1] mismatched label' ' keys. Expected {}, got {}'.format(label_names, label_dims_names)) self._flattened_input_dims = self._flatten_input_dims( self._feature_dims, feature_dims_names, self._label_dims, label_dims_names, label_names, has_labels) def flatten_features_and_labels(self, features, labels, signals=None): """Flattens the `features` and `labels` to a single tensor list.""" self._feature_structure['features'] = features if labels is not None: self._feature_structure['labels'] = labels if signals is not None: self._feature_structure['signals'] = signals return data_nest.flatten(self._feature_structure) def unflatten_features_and_labels(self, flattened_inputs): """Restores the flattened inputs to original features and labels form. Args: flattened_inputs: Flattened inputs for each shard. Returns: A tuple of (`features`, `labels`), where `labels` could be None. Each one, if present, should have identical structure (single tensor vs dict) as the one returned by input_fn. Raises: ValueError: If the number of expected tensors from `flattened_inputs` mismatches the recorded structure. """ unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure, flattened_inputs) return _Inputs( unflattened_inputs['features'], unflattened_inputs.get('labels'), signals=unflattened_inputs.get('signals')) def __init__(self, input_fn, batch_axis, ctx): """Constructor. Args: input_fn: input fn for train or eval. batch_axis: A python tuple of int values describing how each tensor produced by the Estimator `input_fn` should be split across the TPU compute shards. ctx: A `_InternalTPUContext` instance with mode. Raises: ValueError: If both `sharded_features` and `num_cores` are `None`. """ self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder( ctx.input_partition_dims) self._sharded_per_core = ctx.is_input_sharded_per_core() self._input_fn = input_fn self._infeed_queue = None self._ctx = ctx self._batch_axis = batch_axis def generate_infeed_enqueue_ops_and_dequeue_fn(self): """Generates infeed enqueue ops and dequeue_fn.""" # While tf.while_loop is called, the body function, which invokes # `enqueue_fn` passed in, is called to construct the graph. So, input_fn # structure is recorded. enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = ( self._invoke_input_fn_and_record_structure()) self._validate_input_pipeline() def dequeue_fn(): """dequeue_fn is used by TPU to retrieve the tensors.""" # In the model-parallel case, both the host-side and DEVICE-side # computations must agree on the core on which infeed takes place. We # choose to perform infeed on logical core 0 of each replica. values = self._infeed_queue.generate_dequeue_op(tpu_device=0) # The unflatten process uses the structure information recorded above. return self._inputs_structure_recorder.unflatten_features_and_labels( values) return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator) def _invoke_input_fn_and_record_structure(self): """Deploys the input pipeline and record input structure.""" enqueue_ops = [] infeed_queues = [] all_hooks = [] num_hosts = self._ctx.num_hosts tpu_host_placement_fn = self._ctx.tpu_host_placement_function run_infeed_loop_on_coordinator = True if self._sharded_per_core: # Per-Core input pipeline deployment. # Invoke input pipeline for each core and placed on the corresponding # host. for host_id in range(num_hosts): host_device = tpu_host_placement_fn(host_id=host_id) with ops.DEVICE(host_device): with ops.name_scope('input_pipeline_task%d' % (host_id)): enqueue_ops_fn, captured_infeed_queue = ( generate_per_core_enqueue_ops_fn_for_host( self._ctx, self._input_fn, self._inputs_structure_recorder, host_device, host_id)) if _WRAP_INPUT_FN_INTO_WHILE_LOOP: run_infeed_loop_on_coordinator = False enqueue_ops.append( _wrap_computation_in_while_loop( device=host_device, op_fn=enqueue_ops_fn)) else: enqueue_ops.append(enqueue_ops_fn()) # Infeed_queue_getter must be called after enqueue_ops_fn is called. infeed_queues.append(captured_infeed_queue.get()) elif self._ctx.is_input_broadcast_with_iterators(): # Only calls input_fn in host 0. host_device = tpu_host_placement_fn(host_id=0) enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn, self._inputs_structure_recorder, num_hosts)) all_hooks.extend(hooks) if is_dataset: run_infeed_loop_on_coordinator = False wrap_fn = ( _wrap_computation_in_while_loop if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else _wrap_computation_in_while_loop_with_stopping_signals) enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) else: enqueue_ops.append(enqueue_ops_fn()) infeed_queues.append(captured_infeed_queue.get()) else: for host_id in range(num_hosts): host_device = tpu_host_placement_fn(host_id=host_id) with ops.DEVICE(host_device): with ops.name_scope('input_pipeline_task%d' % (host_id)): if self._ctx.is_input_per_host_with_iterators(): enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( generate_per_host_v2_enqueue_ops_fn_for_host( self._ctx, self._input_fn, self._inputs_structure_recorder, host_device, host_id)) else: enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( generate_per_host_enqueue_ops_fn_for_host( self._ctx, self._input_fn, self._inputs_structure_recorder, self._batch_axis, host_device, host_id)) all_hooks.extend(hooks) # NOTE(xiejw): We dispatch here based on the return type of the # users `input_fn`. # # 1. If input_fn returns a Dataset instance, we initialize the # iterator outside of tf.while_loop, and call the iterator.get_next # inside tf.while_loop. This should be always safe. # # 2. If input_fn returns (features, labels), it is too late to wrap # them inside tf.while_loop, as resource initialization cannot be # handled in TF control flow properly. In this case, we will use # python loop to enqueue the data into TPU system. This may be # slow compared to the previous case. if is_dataset: run_infeed_loop_on_coordinator = False wrap_fn = ( _wrap_computation_in_while_loop if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else _wrap_computation_in_while_loop_with_stopping_signals) enqueue_ops.append( wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) else: enqueue_ops.append(enqueue_ops_fn()) infeed_queues.append(captured_infeed_queue.get()) # infeed_queue is used to generate dequeue ops. The only thing it uses for # dequeue is dtypes and types. So, any one can be used. Here, grab the # first one. self._infeed_queue = infeed_queues[0] return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator def _validate_input_pipeline(self): """Validates the input pipeline. Perform some sanity checks to log user friendly information. We should error out to give users better error message. But, if _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break user code, so, log a warning. Raises: RuntimeError: If the validation failed. """ if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): err_msg = ('Input pipeline contains one or more QueueRunners. ' 'It could be slow and not scalable. Please consider ' 'converting your input pipeline to use `tf.data` instead (see ' 'https://www.tensorflow.org/guide/datasets for ' 'instructions.') if _WRAP_INPUT_FN_INTO_WHILE_LOOP: raise RuntimeError(err_msg) else: logging.warn(err_msg) class _ModelFnWrapper(object): """A `model_fn` wrapper. This makes calling model_fn on CPU and TPU easier and more consistent and performs necessary check and mutation required by TPU training and evaluation. In addition, this wrapper manages converting the `model_fn` to a single TPU train and eval step. """ def __init__(self, model_fn, train_cache_fn, eval_cache_fn, config, params, ctx): self._model_fn = model_fn self._train_cache_fn = train_cache_fn self._eval_cache_fn = eval_cache_fn self._config = config self._params = params self._ctx = ctx def call_without_tpu(self, features, labels, is_export_mode): return self._call_model_fn(features, labels, is_export_mode=is_export_mode) def convert_to_single_tpu_train_step(self, dequeue_fn): """Converts user provided model_fn` as a single train step on TPU. The user provided `model_fn` takes input tuple (features, labels) and produces the EstimatorSpec with train_op and loss for train `mode`. This usually represents a single train computation on CPU. For TPU training, a train (computation) step is first wrapped in a tf.while_loop control flow to repeat for many times and then replicated to all TPU shards. Besides the input should be taken from TPU infeed rather than input pipeline (input_fn) directly. To fit TPU loop and replicate pattern, the original train computation should be reformed, which is the returned `train_step`. Args: dequeue_fn: The function to retrieve inputs, features and labels, from TPU infeed dequeue channel. Returns: A tuple of train_fn, host_calls, and captured scaffold_fn. The train_fn representing the train step for TPU. """ host_call = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() captured_training_hooks = _CapturedObject() def train_step(loss, *cache): """Training step function for use inside a while loop.""" if not self._params.get('track_mean', False): del loss # unused; required in function signature. inputs = dequeue_fn() features, labels = inputs.features_and_labels() # Consume the current cache estimator_spec = self._verify_estimator_spec( self._call_model_fn(features, labels, cache=cache)) # Retrieve the new returned cache """ `cache` consists of a list of tensors, potentially empty (of length 0) """ cache = estimator_spec.cache new_loss, train_op = estimator_spec.loss, estimator_spec.train_op if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access captured_scaffold_fn.capture(estimator_spec.scaffold_fn) else: captured_scaffold_fn.capture(None) captured_training_hooks.capture(estimator_spec.training_hooks) # We must run train_op to update the variables prior to running the # outfeed. with ops.control_dependencies([train_op]): host_call_outfeed_ops = [] if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access and estimator_spec.host_call is not None): host_call.record({'host_call': estimator_spec.host_call}) host_call_outfeed_ops = host_call.create_enqueue_op() with ops.control_dependencies(host_call_outfeed_ops): if self._params.get('track_mean', False): loss = tensorflow.stop_gradient(loss) return [math_ops.add(loss, new_loss)] + cache else: return [array_ops.identity(new_loss)] + cache return (train_step, host_call, captured_scaffold_fn, captured_training_hooks) def convert_to_single_tpu_eval_step(self, dequeue_fn): """Converts user provided model_fn` as a single eval step on TPU. Similar to training, the user provided `model_fn` takes input tuple (features, labels) and produces the TPUEstimatorSpec with eval_metrics for eval `mode`. This usually represents a single evaluation computation on CPU. For TPU evaluation, a eval (computation) step is first wrapped in a tf.while_loop control flow to repeat for many times and then replicated to all TPU shards. Besides the input and output are slightly different. Input, features and labels, should be taken from TPU infeed rather than input pipeline (input_fn) directly. Output is managed in two stages. First, the model outputs as the result of evaluation computation, usually model logits, should be transferred from TPU system to CPU. Then, all model outputs are concatenated first on CPU and sent to the metric_fn for metrics computation. To fit TPU evaluation pattern, the original eval computation should be reformed, which is the returned `eval_step`. Args: dequeue_fn: The function to retrieve inputs, features and labels, from TPU infeed dequeue channel. Returns: A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn representing the eval step for TPU. """ host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() captured_eval_hooks = _CapturedObject() def eval_step(total_loss, *cache): """Evaluation step function for use inside a while loop.""" inputs = dequeue_fn() features, labels = inputs.features_and_labels() # Consume the current cache tpu_estimator_spec = self._call_model_fn(features, labels, cache=cache) if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access raise RuntimeError( 'estimator_spec used by TPU evaluation must have type' '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) # Retrieve the new returned cache cache = tpu_estimator_spec.cache loss = tpu_estimator_spec.loss captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks) to_record = {} if tpu_estimator_spec.eval_metrics: to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics if tpu_estimator_spec.host_call is not None: # We assume that evaluate won't update global step, so we don't wrap # this host_call. to_record['host_call'] = tpu_estimator_spec.host_call host_calls.record(to_record) with ops.control_dependencies(host_calls.create_enqueue_op()): return [math_ops.add(total_loss, loss)] + cache return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks def convert_to_single_tpu_predict_step(self, dequeue_fn): """Converts user provided model_fn` as a single predict step on TPU. Args: dequeue_fn: The function to retrieve inputs, features and labels, from TPU infeed dequeue channel. Returns: A tuple of predict_fn, host_calls, and captured scaffold_fn. The predict_fn representing the predict step for TPU. """ host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() captured_predict_hooks = _CapturedObject() def predict_step(unused_scalar_stopping_signal): """Evaluation step function for use inside a while loop.""" inputs = dequeue_fn() features, labels = inputs.features_and_labels() stopping_signals = inputs.signals() assert stopping_signals is not None, ( 'Internal Error: `signals` is missing.') tpu_estimator_spec = self._call_model_fn( features, labels, is_export_mode=False) if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access raise RuntimeError( 'estimator_spec used by TPU prediction must have type' '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks) to_record = {} identity_fn = lambda **kwargs: kwargs to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] to_record['signals'] = [identity_fn, stopping_signals] if tpu_estimator_spec.host_call is not None: to_record['host_call'] = tpu_estimator_spec.host_call host_calls.record(to_record) with ops.control_dependencies(host_calls.create_enqueue_op()): return _StopSignals.as_scalar_stopping_signal(stopping_signals) return (predict_step, host_calls, captured_scaffold_fn, captured_predict_hooks) def _verify_tpu_spec_predictions(self, predictions): """Validates TPUEstimatorSpec.predictions dict.""" # TODO(xiejw): Adds validation for prediction dictionrary. # TODO(xiejw): Adds support for single tensor as predictions. if not isinstance(predictions, dict): raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') for (key, tensor) in predictions.items(): if tensor.shape[0].value is None: raise ValueError( 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' 'dynamic shape (should be static). Tensor: {}'.format( key, tensor)) return predictions def _validate_model_features_and_labels(self, features, labels, is_export_mode): """Validates that the features and labels for the model function are valid. A valid features/labels object is the one with: - Type: Tensor or a dictionary of Tensors - Static shape if is_export_mode is False. Args: features: the features that would be input to the model function. labels: the labels that would be input to the model function. is_export_mode: boolean value specifying if in export mode. Raises: TypeError: If features/labels are not of the correct type. ValueError: If features/labels have dynamic shape. """ def validate(obj, obj_name): """Helper validate function.""" if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict): raise TypeError( 'The {} to the model returned by input_fn must be either a Tensor ' 'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name, obj)) if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): return if isinstance(obj, ops.Tensor): if not obj.get_shape().is_fully_defined(): raise ValueError( 'The {} to the model returned by input_fn must have static shape.' ' Tensor: {}'.format(obj_name, obj)) else: for (key, value) in obj.items(): flattened_tensors = data_nest.flatten(value) for tensor in flattened_tensors: if not tensor.get_shape().is_fully_defined(): raise ValueError( 'The {} to the model returned by input_fn must have static ' 'shape. Key: \'{}\', Tensor: {}'.format( obj_name, key, tensor)) validate(features, 'features') if labels is not None: validate(labels, 'labels') def _call_model_fn(self, features, labels, cache=None, is_export_mode=False): """Calls the model_fn with required parameters.""" self._validate_model_features_and_labels(features, labels, is_export_mode) model_fn_args = function_utils.fn_args(self._model_fn) kwargs = {} # Makes deep copy with `config` and params` in case user mutates them. config = copy.deepcopy(self._config) params = copy.deepcopy(self._params) if 'labels' in model_fn_args: kwargs['labels'] = labels elif labels is not None: raise ValueError( 'model_fn does not take labels, but input_fn returns labels.') if 'mode' in model_fn_args: kwargs['mode'] = self._ctx.mode if 'config' in model_fn_args: kwargs['config'] = config if 'params' in model_fn_args: kwargs['params'] = params if cache is not None: params['cache'] = cache if 'params' not in model_fn_args: raise ValueError('model_fn ({}) does not include params argument, ' 'required by TPUEstimator to pass batch size as ' 'params[\'batch_size\']'.format(self._model_fn)) if is_export_mode: batch_size_for_model_fn = None else: batch_size_for_model_fn = self._ctx.batch_size_for_model_fn if batch_size_for_model_fn is not None: _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) if not running_on_cpu: user_context = tpu_context.TPUContext( internal_ctx=self._ctx, call_from_input_fn=False) _add_item_to_params(params, _CTX_KEY, user_context) estimator_spec = self._model_fn(features=features, **kwargs) if (running_on_cpu and isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`. return estimator_spec.as_estimator_spec() else: return estimator_spec def _verify_estimator_spec(self, estimator_spec): """Validates the estimator_spec.""" if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access return estimator_spec err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' if estimator_spec.training_chief_hooks: raise ValueError( err_msg.format('training_chief_hooks') + 'If you want' + ' to pass training hooks, please pass via training_hooks.') if estimator_spec.scaffold: logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. ' 'Please use TPUEstimatorSpec.') return estimator_spec class _OutfeedHostCall(object): """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec.""" def __init__(self, ctx): self._ctx = ctx self._names = [] # All of these are dictionaries of lists keyed on the name. self._host_fns = {} self._tensor_keys = collections.defaultdict(list) self._tensors = collections.defaultdict(list) self._tensor_dtypes = collections.defaultdict(list) self._tensor_shapes = collections.defaultdict(list) @staticmethod def validate(host_calls): """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`.""" for name, host_call in host_calls.items(): if not isinstance(host_call, (tuple, list)): raise ValueError('{} should be tuple or list'.format(name)) if len(host_call) != 2: raise ValueError('{} should have two elements.'.format(name)) if not callable(host_call[0]): raise TypeError('{}[0] should be callable.'.format(name)) if not isinstance(host_call[1], (tuple, list, dict)): raise ValueError('{}[1] should be tuple or list, or dict.'.format(name)) if isinstance(host_call[1], (tuple, list)): fullargspec = tf_inspect.getfullargspec(host_call[0]) fn_args = function_utils.fn_args(host_call[0]) # wrapped_hostcall_with_global_step uses varargs, so we allow that. if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): raise RuntimeError( 'In TPUEstimatorSpec.{}, length of tensors {} does not match ' 'method args of the function, which takes {}.'.format( name, len(host_call[1]), len(fn_args))) @staticmethod def create_cpu_hostcall(host_calls): """Runs on the host_call on CPU instead of TPU when use_tpu=False.""" _OutfeedHostCall.validate(host_calls) ret = {} for name, host_call in host_calls.items(): host_fn, tensors = host_call if isinstance(tensors, (tuple, list)): ret[name] = host_fn(*tensors) else: # Must be dict. try: ret[name] = host_fn(**tensors) except TypeError as e: logging.warning( 'Exception while calling %s: %s. It is likely the tensors ' '(%s[1]) do not match the ' 'function\'s arguments', name, e, name) raise e return ret def record(self, host_calls): """Records the host_call structure.""" for name, host_call in host_calls.items(): host_fn, tensor_list_or_dict = host_call self._names.append(name) self._host_fns[name] = host_fn if isinstance(tensor_list_or_dict, dict): for (key, tensor) in six.iteritems(tensor_list_or_dict): self._tensor_keys[name].append(key) self._tensors[name].append(tensor) self._tensor_dtypes[name].append(tensor.dtype) self._tensor_shapes[name].append(tensor.shape) else: # List or tuple. self._tensor_keys[name] = None for tensor in tensor_list_or_dict: self._tensors[name].append(tensor) self._tensor_dtypes[name].append(tensor.dtype) self._tensor_shapes[name].append(tensor.shape) def create_enqueue_op(self): """Create the op to enqueue the recorded host_calls. Returns: A list of enqueue ops, which is empty if there are no host calls. """ if not self._names: return [] tensors = [] # TODO(jhseu): Consider deduping tensors. for name in self._names: tensors.extend(self._tensors[name]) with ops.DEVICE(tpu.core(0)): return [tpu_ops.outfeed_enqueue_tuple(tensors)] def create_tpu_hostcall(self): """Sends the tensors through outfeed and runs the host_fn on CPU. The tensors are concatenated along dimension 0 to form a global tensor across all shards. The concatenated function is passed to the host_fn and executed on the first host. Returns: A dictionary mapping name to the return type of the host_call by that name. Raises: RuntimeError: If outfeed tensor is scalar. """ if not self._names: return {} ret = {} # For each i, dequeue_ops[i] is a list containing the tensors from all # shards. This list is concatenated later. dequeue_ops = [] tensor_dtypes = [] tensor_shapes = [] for name in self._names: for _ in self._tensors[name]: dequeue_ops.append([]) for dtype in self._tensor_dtypes[name]: tensor_dtypes.append(dtype) for shape in self._tensor_shapes[name]: tensor_shapes.append(shape) # Outfeed ops execute on each replica's first logical core. Note: we must # constraint it such that we have at most one outfeed dequeue and enqueue # per replica. for i in xrange(self._ctx.num_replicas): host_device, ordinal_id = self._ctx.device_for_replica(i) with ops.DEVICE(host_device): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=ordinal_id) for j, item in enumerate(outfeed_tensors): dequeue_ops[j].append(item) # Deconstruct dequeue ops. dequeue_ops_by_name = {} pos = 0 for name in self._names: dequeue_ops_by_name[name] = dequeue_ops[pos:pos+len(self._tensors[name])] pos += len(self._tensors[name]) # It is assumed evaluation always happens on single host TPU system. So, # place all ops on tpu host if possible. # # TODO(jhseu): Evaluate whether this is right for summaries. with ops.DEVICE(self._ctx.tpu_host_placement_function(replica_id=0)): for name in self._names: dequeue_ops = dequeue_ops_by_name[name] for i, item in enumerate(dequeue_ops): if dequeue_ops[i][0].shape.ndims == 0: raise RuntimeError( 'All tensors outfed from TPU should preserve batch size ' 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) # TODO(xiejw): Allow users to specify the axis for batch size # dimension. dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) if self._tensor_keys[name] is not None: # The user-provided eval_metrics[1] is a dict. dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) try: ret[name] = self._host_fns[name](**dequeue_ops) except TypeError as e: logging.warning( 'Exception while calling %s: %s. It is likely the tensors ' '(%s[1]) do not match the ' 'function\'s arguments', name, e, name) raise e else: ret[name] = self._host_fns[name](*dequeue_ops) return ret class _OutfeedHostCallHook(session_run_hook.SessionRunHook): """Hook to run host calls when use_tpu=False.""" def __init__(self, tensors): self._tensors = tensors def begin(self): # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than # create a separate hook to guarantee execution order, because summaries # need to be initialized before the outfeed thread starts. # TODO(jhseu): Make a wrapper hook instead? self._init_ops = contrib_summary.summary_writer_initializer_op() # Get all the writer resources from the initializer, so we know what to # flush. self._finalize_ops = [] for op in self._init_ops: self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) def after_create_session(self, session, coord): session.run(self._init_ops) def before_run(self, run_context): return basic_session_run_hooks.SessionRunArgs(self._tensors) def end(self, session): session.run(self._finalize_ops) class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): """Calculate and report global_step/sec and examples/sec during runtime.""" def __init__(self, batch_size, every_n_steps=100, every_n_secs=None, output_dir=None, summary_writer=None): self._batch_size = batch_size super(ExamplesPerSecondHook, self).__init__( every_n_steps=every_n_steps, every_n_secs=every_n_secs, output_dir=output_dir, summary_writer=summary_writer) def _log_and_record(self, elapsed_steps, elapsed_time, global_step): global_step_per_sec = elapsed_steps / elapsed_time examples_per_sec = self._batch_size * global_step_per_sec if self._summary_writer is not None: global_step_summary = Summary(value=[ Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec) ]) example_summary = Summary(value=[ Summary.Value(tag='examples/sec', simple_value=examples_per_sec) ]) self._summary_writer.add_summary(global_step_summary, global_step) self._summary_writer.add_summary(example_summary, global_step) logging.info('global_step/sec: %g', global_step_per_sec) logging.info('examples/sec: %g', examples_per_sec) class InstallSignalHandlerHook(session_run_hook.SessionRunHook): """Change SIGINT (CTRL^C) handler to force quit the process. The default behavior often results in hanging processes. The original handler is restored after training/evaluation. """ def __init__(self): self._signal_fn = signal.getsignal(signal.SIGINT) def before_run(self, run_context): signal.signal(signal.SIGINT, signal.SIG_DFL) def end(self, session): signal.signal(signal.SIGINT, self._signal_fn) class TPUEstimator(estimator_lib.Estimator): """Estimator with TPU support. TPUEstimator also supports training on CPU and GPU. You don't need to define a separate `tf.estimator.Estimator`. TPUEstimator handles many of the details of running on TPU devices, such as replicating inputs and models for each core, and returning to host periodically to run hooks. TPUEstimator transforms a global batch size in params to a per-shard batch size when calling the `input_fn` and `model_fn`. Users should specify global batch size in constructor, and then get the batch size for each shard in `input_fn` and `model_fn` by `params['batch_size']`. - For training, `model_fn` gets per-core batch size; `input_fn` may get per-core or per-host batch size depending on `per_host_input_for_training` in `TPUConfig` (See docstring for TPUConfig for details). - For evaluation and prediction, `model_fn` gets per-core batch size and `input_fn` get per-host batch size. Evaluation ========== `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics` for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return `EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case the following discussion on TPU evaluation does not apply. `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns a dict from metric string name to the result of calling a metric function, namely a `(metric_tensor, update_op)` tuple. One can set `use_tpu` to `False` for testing. All training, evaluation, and predict will be executed on CPU. `input_fn` and `model_fn` will receive `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`. Current limitations: -------------------- 1. TPU evaluation only works on a single host (one TPU worker) except BROADCAST mode. 2. `input_fn` for evaluation should **NOT** raise an end-of-input exception (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all batches should have the same size. Example (MNIST): ---------------- ``` # The metric Fn which runs on CPU. def metric_fn(labels, logits): predictions = tf.argmax(logits, 1) return { 'accuracy': tf.metrics.precision( labels=labels, predictions=predictions), } # Your model Fn which runs on TPU (eval_metrics is list in this example) def model_fn(features, labels, mode, config, params): ... logits = ... if mode = tf.estimator.ModeKeys.EVAL: return tpu_estimator.TPUEstimatorSpec( mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits])) # or specify the eval_metrics tensors as dict. def model_fn(features, labels, mode, config, params): ... final_layer_output = ... if mode = tf.estimator.ModeKeys.EVAL: return tpu_estimator.TPUEstimatorSpec( mode=mode, loss=loss, eval_metrics=(metric_fn, { 'labels': labels, 'logits': final_layer_output, })) ``` Prediction ========== Prediction on TPU is an experimental feature to support large batch inference. It is not designed for latency-critical system. In addition, due to some usability issues, for prediction with small dataset, CPU `.predict`, i.e., creating a new `TPUEstimator` instance with `use_tpu=False`, might be more convenient. Note: In contrast to TPU training/evaluation, the `input_fn` for prediction *should* raise an end-of-input exception (`OutOfRangeError` or `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be precise, the ops created by `input_fn` produce one batch of the data. The `predict()` API processes one batch at a time. When reaching the end of the data source, an end-of-input exception should be raised by one of these operations. The user usually does not need to do this manually. As long as the dataset is not repeated forever, the `tf.data` API will raise an end-of-input exception automatically after the last batch has been produced. Note: Estimator.predict returns a Python generator. Please consume all the data from the generator so that TPUEstimator can shutdown the TPU system properly for user. Current limitations: -------------------- 1. TPU prediction only works on a single host (one TPU worker). 2. `input_fn` must return a `Dataset` instance rather than `features`. In fact, .train() and .evaluate() also support Dataset as return value. Example (MNIST): ---------------- ``` height = 32 width = 32 total_examples = 100 def predict_input_fn(params): batch_size = params['batch_size'] images = tf.random_uniform( [total_examples, height, width, 3], minval=-1, maxval=1) dataset = tf.data.Dataset.from_tensor_slices(images) dataset = dataset.map(lambda images: {'image': images}) dataset = dataset.batch(batch_size) return dataset def model_fn(features, labels, params, mode): # Generate predictions, called 'output', from features['image'] if mode == tf.estimator.ModeKeys.PREDICT: return tf.contrib.tpu.TPUEstimatorSpec( mode=mode, predictions={ 'predictions': output, 'is_padding': features['is_padding'] }) tpu_est = TPUEstimator( model_fn=model_fn, ..., predict_batch_size=16) # Fully consume the generator so that TPUEstimator can shutdown the TPU # system. for item in tpu_est.predict(input_fn=input_fn): # Filter out item if the `is_padding` is 1. # Process the 'predictions' ``` Exporting ========= `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`, and another with `tag_constants.SERVING` and `tag_constants.TPU`. At serving time, these tags are used to select metagraph to load. Before running the graph on TPU, TPU system needs to be initialized. If TensorFlow Serving model-server is used, this is done automatically. If not, please call `session.run(tpu.initialize_system())`. `tpu.outside_compilation` can be used to wrap TPU incompatible ops in `model_fn`. Example: ---------------- ``` def model_fn(features, labels, mode, config, params): ... logits = ... export_outputs = { 'logits': export_output_lib.PredictOutput( {'logits': logits}) } def host_call(logits): class_ids = math_ops.argmax(logits) classes = string_ops.as_string(class_ids) export_outputs['classes'] = export_output_lib.ClassificationOutput(classes=classes) tpu.outside_compilation(host_call, logits) ... ``` """ def __init__(self, model_fn=None, train_cache_fn=None, eval_cache_fn=None, model_dir=None, config=None, params=None, use_tpu=True, train_batch_size=None, eval_batch_size=None, predict_batch_size=None, batch_axis=None, eval_on_tpu=True, export_to_tpu=True, warm_start_from=None): """Constructs an `TPUEstimator` instance. Args: model_fn: Model function as required by `Estimator` which returns EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks', and `prediction_hooks` must not capure any TPU Tensor inside the model_fn. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `None`, the model_dir in `config` will be used if set. If both are set, they must be same. If both are `None`, a temporary directory will be used. config: An `tpu_config.RunConfig` configuration object. Cannot be `None`. params: An optional `dict` of hyper parameters that will be passed into `input_fn` and `model_fn`. Keys are names of parameters, values are basic python types. There are reserved keys for `TPUEstimator`, including 'batch_size'. use_tpu: A bool indicating whether TPU support is enabled. Currently, - TPU training and evaluation respect this bit, but eval_on_tpu can override execution of eval. See below. - Predict still happens on CPU. train_batch_size: An int representing the global training batch size. TPUEstimator transforms this global batch size to a per-shard batch size, as params['batch_size'], when calling `input_fn` and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be divisible by total number of replicas. eval_batch_size: An int representing evaluation batch size. Must be divisible by total number of replicas. predict_batch_size: An int representing the prediction batch size. Must be divisible by total number of replicas. batch_axis: A python tuple of int values describing how each tensor produced by the Estimator `input_fn` should be split across the TPU compute shards. For example, if your input_fn produced (images, labels) where the images tensor is in `HWCN` format, your shard dimensions would be [3, 0], where 3 corresponds to the `N` dimension of your images Tensor, and 0 corresponds to the dimension along which to split the labels to match up with the corresponding images. If None is supplied, and per_host_input_for_training is True, batches will be sharded based on the major dimension. If tpu_config.per_host_input_for_training is False or `PER_HOST_V2`, batch_axis is ignored. eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. export_to_tpu: If True, `export_savedmodel()` exports a metagraph for serving on TPU besides the one on CPU. warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a `tf.estimator.WarmStartSettings` object to fully configure warm-starting. If the string filepath is provided instead of a `WarmStartSettings`, then all variables are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. Raises: ValueError: `params` has reserved keys already. """ if config is None or not isinstance(config, tpu_config.RunConfig): raise ValueError( '`config` must be provided with type `tpu_config.RunConfig`') if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS): raise ValueError('{} are reserved keys but existed in params {}.'.format( _RESERVED_PARAMS_KEYS, params)) if use_tpu: # Perform some very basic validations. More validations will be found in # _InternalTPUContext. if train_batch_size is None: raise ValueError('`train_batch_size` cannot be `None`') util_lib.check_positive_integer(train_batch_size, 'train_batch_size') if (config.tpu_config.per_host_input_for_training is tpu_config.InputPipelineConfig.PER_SHARD_V1 and config.tpu_config.num_cores_per_replica): raise ValueError( 'Model parallelism only supports per host input for training. ' 'Please adjust TPURunconfig.per_host_input_for_training.') if eval_batch_size is not None: util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size') if predict_batch_size is not None: util_lib.check_positive_integer(predict_batch_size, 'predict_batch_size') # Verifies the model_fn signature according to Estimator framework. estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access # We cannot store config and params in this constructor as parent # constructor might change them, such as assigning a temp dir for # config.model_dir. model_function = self._augment_model_fn( model_fn, train_cache_fn, eval_cache_fn, batch_axis) # Overwrite log_step_count_steps to disable TensorLoggingHook and # StepCounterHook from being created in Estimator. TPUEstimator already # added equivalent hooks in _augment_model_fn above. self._log_every_n_steps = config.log_step_count_steps config = config.replace(log_step_count_steps=None) # Passing non-None params as wrapped model_fn has it. params = params or {} super(TPUEstimator, self).__init__( model_fn=model_function, model_dir=model_dir, config=config, params=params, warm_start_from=warm_start_from) self._iterations_per_training_loop = ( self._config.tpu_config.iterations_per_loop) # All properties passed to _InternalTPUContext are immutable. # pylint: disable=protected-access self._ctx = tpu_context._get_tpu_context( self._config, train_batch_size, eval_batch_size, predict_batch_size, use_tpu, eval_on_tpu) self._export_to_tpu = export_to_tpu self._is_input_fn_invoked = None self._rendezvous = {} def _add_meta_graph_for_mode(self, builder, input_receiver_fn_map, checkpoint_path, strip_default_attrs, save_variables=True, mode=model_fn_lib.ModeKeys.PREDICT, export_tags=None, check_variables=True): if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: raise NotImplementedError( 'TPUEstimator only handles mode PREDICT for exporting ' 'when `export_to_tpu` is `True`; ' 'got {}.'.format(mode)) (super(TPUEstimator, self). _add_meta_graph_for_mode(builder, input_receiver_fn_map, checkpoint_path, strip_default_attrs, save_variables, mode=mode, export_tags=export_tags, check_variables=check_variables)) if self._export_to_tpu: input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode]} export_tags = [tag_constants.SERVING, tag_constants.TPU] mode = _REWRITE_FOR_INFERENCE_MODE # See b/110052256 for why `check_variables` is `False`. (super(TPUEstimator, self). _add_meta_graph_for_mode(builder, input_receiver_fn_map, checkpoint_path, strip_default_attrs, save_variables=False, mode=mode, export_tags=export_tags, check_variables=False)) def _call_model_fn(self, features, labels, mode, config): if mode == _REWRITE_FOR_INFERENCE_MODE: return self._call_model_fn_for_inference(features, labels, mode, config) else: return super(TPUEstimator, self)._call_model_fn( features, labels, mode, config) def _call_model_fn_for_inference(self, features, labels, mode, config): """Wraps `_call_model_fn` for `export_savedmodel`.""" if mode != _REWRITE_FOR_INFERENCE_MODE: raise ValueError('mode must be {}; ' 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) capture = _CapturedObject() def computation(): """Compute tpu tensors used in export_outputs. Passed to rewrite_for_inference so that model_fn will be called under the rewriting contexts. Only tpu tensors are returned, but export_outputs and scaffold are captured. Returns: A list of Tensors used in export_outputs and not marked for outside_compilation. """ # We should only call model fn once and it should be inside `computation` # so that building the graph will happen under `rewrite_for_inference`. mode = model_fn_lib.ModeKeys.PREDICT estimator_spec = self._call_model_fn(features, labels, mode, config) # We pick the TPU tensors out from `export_output` and later return them # from `computation` for rewriting. tensors_dict = collections.OrderedDict( (k, _export_output_to_tensors(v)) for k, v in six.iteritems(estimator_spec.export_outputs) ) tensors = nest.flatten(tensors_dict) tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] # We cannot return anything other than `tpu_tensors` here so we capture # the rest for later use. capture.capture((estimator_spec, tensors_dict, tensors)) return tpu_tensors tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) estimator_spec, tensors_dict, tensors = capture.get() # Reconstruct `tensors`, but with `tpu_tensors` replaced with # `tpu_tensors_on_cpu`. new_tensors = [] for t in tensors: if _is_tpu_tensor(t): new_tensors.append(tpu_tensors_on_cpu.pop(0)) elif t is None: new_tensors.append(None) else: # Only fetching `tpu_tensors_on_cpu` does not trigger # TPU computation and blocks, so we add the control dependency here. control_inputs = (tpu_tensors_on_cpu if isinstance(tpu_tensors_on_cpu, (list, tuple)) else (tpu_tensors_on_cpu,)) with ops.control_dependencies(control_inputs): new_tensors.append(array_ops.identity(t)) # Reconstruct `tensors_dict`. new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) # Reconstruct `export_outputs`. export_outputs = estimator_spec.export_outputs new_export_outputs = collections.OrderedDict( (k, _clone_export_output_with_tensors(export_outputs[k], v)) for k, v in six.iteritems(new_tensors_dict) ) return estimator_spec._replace(export_outputs=new_export_outputs) def _create_global_step(self, graph): """Creates a global step suitable for TPUs. Args: graph: The graph in which to create the global step. Returns: A global step `Tensor`. Raises: ValueError: if the global step tensor is already defined. """ return _create_global_step(graph) def _convert_train_steps_to_hooks(self, steps, max_steps): with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: if ctx.is_running_on_cpu(): return super(TPUEstimator, self)._convert_train_steps_to_hooks( steps, max_steps) # On TPU. if steps is None and max_steps is None: raise ValueError( 'For TPU training, one of `steps` or `max_steps` must be set. ' 'Cannot be both `None`.') # Estimator.train has explicit positiveness check. if steps is not None: util_lib.check_positive_integer(steps, 'Train steps') if max_steps is not None: util_lib.check_positive_integer(max_steps, 'Train max_steps') return [ _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps) ] def _convert_eval_steps_to_hooks(self, steps): with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: if ctx.is_running_on_cpu(): return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) if steps is None: raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') util_lib.check_positive_integer(steps, 'Eval steps') return [ evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access num_evals=steps), _SetEvalIterationsHook(steps) ] def _call_input_fn(self, input_fn, mode): """Calls the input function. Args: input_fn: The input function. mode: ModeKeys Returns: Either features or (features, labels) where features and labels are: features - `Tensor` or dictionary of string feature name to `Tensor`. labels - `Tensor` or dictionary of `Tensor` with labels. Raises: ValueError: if input_fn takes invalid arguments or does not have `params`. """ input_fn_args = function_utils.fn_args(input_fn) config = self.config # a deep copy. kwargs = {} if 'params' in input_fn_args: kwargs['params'] = self.params # a deep copy. else: raise ValueError('input_fn ({}) does not include params argument, ' 'required by TPUEstimator to pass batch size as ' 'params["batch_size"]'.format(input_fn)) if 'config' in input_fn_args: kwargs['config'] = config if 'mode' in input_fn_args: kwargs['mode'] = mode # Records the fact input_fn has been invoked. self._is_input_fn_invoked = True with self._ctx.with_mode(mode) as ctx: # Setting the batch size in params first. This helps user to have same # input_fn for use_tpu=True/False. batch_size_for_input_fn = ctx.batch_size_for_input_fn if batch_size_for_input_fn is not None: _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY, batch_size_for_input_fn) # For export_savedmodel, input_fn is never passed to Estimator. So, # `is_export_mode` must be False. if ctx.is_running_on_cpu(is_export_mode=False): with ops.DEVICE('/DEVICE:CPU:0'): return input_fn(**kwargs) # For TPU computation, input_fn should be invoked in a tf.while_loop for # performance. While constructing the tf.while_loop, the structure of # inputs returned by the `input_fn` needs to be recorded. The structure # includes whether features or labels is dict or single Tensor, dict keys, # tensor shapes, and dtypes. The recorded structure is used to create the # infeed dequeue ops, which must be wrapped and passed as a Fn, called # inside the TPU computation, as the TPU computation is wrapped inside a # tf.while_loop also. So, we either pass input_fn to model_fn or pass # dequeue_fn to model_fn. Here, `input_fn` is passed directly as # `features` in `model_fn` signature. def _input_fn(ctx): _add_item_to_params(kwargs['params'], _CTX_KEY, ctx) return input_fn(**kwargs) return _input_fn def _validate_features_in_predict_input(self, result): """Skip the validation. For TPUEstimator, we do not need to check the result type. `_InputPipeline` has stronger check. Parent class's check generates confusing warning msg. Args: result: `features` returned by input_fn. """ pass def train(self, input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None): rendezvous = error_handling.ErrorRendezvous(num_sources=3) self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous try: return super(TPUEstimator, self).train( input_fn=input_fn, hooks=hooks, steps=steps, max_steps=max_steps, saving_listeners=saving_listeners ) except Exception: # pylint: disable=broad-except rendezvous.record_error('training_loop', sys.exc_info()) finally: rendezvous.record_done('training_loop') rendezvous.raise_errors() def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None): rendezvous = error_handling.ErrorRendezvous(num_sources=3) self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous try: return super(TPUEstimator, self).evaluate( input_fn, steps=steps, hooks=hooks, checkpoint_path=checkpoint_path, name=name ) except Exception: # pylint: disable=broad-except rendezvous.record_error('evaluation_loop', sys.exc_info()) finally: rendezvous.record_done('evaluation_loop') rendezvous.raise_errors() def predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None, yield_single_examples=True): rendezvous = error_handling.ErrorRendezvous(num_sources=3) self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous try: for result in super(TPUEstimator, self).predict( input_fn=input_fn, predict_keys=predict_keys, hooks=hooks, checkpoint_path=checkpoint_path, yield_single_examples=yield_single_examples): yield result except Exception: # pylint: disable=broad-except rendezvous.record_error('prediction_loop', sys.exc_info()) finally: rendezvous.record_done('prediction_loop') rendezvous.raise_errors() rendezvous.record_done('prediction_loop') rendezvous.raise_errors() def _augment_model_fn(self, model_fn, train_cache_fn, eval_cache_fn, batch_axis): """Returns a new model_fn, which wraps the TPU support.""" def _model_fn(features, labels, mode, config, params): """A Estimator `model_fn` for TPUEstimator.""" with self._ctx.with_mode(mode) as ctx: model_fn_wrapper = _ModelFnWrapper(model_fn, train_cache_fn, eval_cache_fn, config, params, ctx) # `input_fn` is called in `train()`, `evaluate()`, and `predict()`, # but not in `export_savedmodel()`. if self._is_input_fn_invoked: is_export_mode = False else: is_export_mode = True # Clear the bit. self._is_input_fn_invoked = None # examples_hook is added to training_hooks for both CPU and TPU # execution. examples_hook = ExamplesPerSecondHook( ctx.global_batch_size, output_dir=self.model_dir, every_n_steps=self._log_every_n_steps) if ctx.is_running_on_cpu(is_export_mode=is_export_mode): logging.info('Running %s on CPU', mode) estimator_spec = model_fn_wrapper.call_without_tpu( features, labels, is_export_mode=is_export_mode) estimator_spec = estimator_spec._replace( training_hooks=estimator_spec.training_hooks + (examples_hook,)) return estimator_spec assert labels is None, '`labels` passed to `model_fn` must be `None`.' # TPUEstimator._call_input_fn passes `input_fn` as features to here. assert callable(features), '`input_fn` is not callable.' input_fn = features input_holders = _InputPipeline(input_fn, batch_axis, ctx) enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) graph = ops.get_default_graph() for enqueue_op in enqueue_ops: if isinstance(enqueue_op, list): graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op) else: graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) if mode == model_fn_lib.ModeKeys.TRAIN: loss, host_call, scaffold, training_hooks = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) if model_fn_wrapper._params.get('track_mean', False): iterations_per_loop_var = _create_or_get_iterations_per_loop() loss = math_ops.div(loss, math_ops.cast( iterations_per_loop_var, dtype=loss.dtype)) host_ops = host_call.create_tpu_hostcall() if host_ops is None: host_ops = [] shutdown_hooks = [] shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE', 'shutdown_worker') if shutdown_mode: if shutdown_mode == 'shutdown_worker': finalizer_hooks = [ session_support.ShutdownLameWorkers(timeout_ms=60*1000), ] elif shutdown_mode == 'shutdown_computation': finalizer_hooks = [ session_support.RestartComputation(timeout_ms=60*1000), ] else: raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % shutdown_mode) shutdown_hooks.append(session_support.GracefulShutdownHook( checkpoint_prefix=self.model_dir + '/model.ckpt', on_shutdown_hooks=finalizer_hooks )) with ops.control_dependencies([loss]): global_step = array_ops.identity(training.get_global_step()) hooks = input_hooks + shutdown_hooks logging_hook_frequency = ( # Divide and round up (self._log_every_n_steps + self._config.tpu_config.iterations_per_loop - 1) // self._config.tpu_config.iterations_per_loop) iterations_per_loop = array_ops.identity( _create_or_get_iterations_per_loop()) hooks.extend([ TPUInfeedOutfeedSessionHook( ctx, enqueue_ops, host_ops, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode], ), InstallSignalHandlerHook(), training.LoggingTensorHook( { 'loss': array_ops.identity(loss), 'ppl': tensorflow.exp(loss), 'bpc': loss / tensorflow.constant(math.log(2)), '#iter/loop': iterations_per_loop, 'global step': global_step, }, every_n_iter=logging_hook_frequency) ]) examples_hook._set_steps_per_run( # pylint: disable=protected-access self._config.tpu_config.iterations_per_loop) hooks.append(examples_hook) if training_hooks: hooks.extend(training_hooks) chief_hooks = [] if (self._config.save_checkpoints_secs or self._config.save_checkpoints_steps): checkpoint_hook = training.CheckpointSaverHook( self.model_dir, save_secs=self._config.save_checkpoints_secs, save_steps=self._config.save_checkpoints_steps, scaffold=scaffold) checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access self._config.tpu_config.iterations_per_loop) chief_hooks.append(checkpoint_hook) summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) with ops.control_dependencies([loss]): update_ops = _sync_variables_ops() # Validate the TPU training graph to catch basic errors _validate_tpu_training_graph() train_op = control_flow_ops.group(*update_ops) graph.add_to_collection(_TPU_TRAIN_OP, train_op) return model_fn_lib.EstimatorSpec( mode, loss=loss, training_chief_hooks=chief_hooks, training_hooks=hooks, train_op=train_op, scaffold=scaffold) if mode == model_fn_lib.ModeKeys.EVAL: total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) iterations_per_loop_var = _create_or_get_iterations_per_loop() mean_loss = math_ops.div(total_loss, math_ops.cast( iterations_per_loop_var, dtype=total_loss.dtype)) # Creates a dummy metric update_op for all metrics. Estimator expects # all metrics in eval_metric_ops have update_op and calls them one by # one. The real metric update_ops are invoked in a separated thread. # So, here give Estimator the dummy op for all metrics. with ops.control_dependencies([mean_loss]): # After TPU evaluation computation is done (the mean_loss tensor), # reads all variables back from TPU and updates the eval step # counter properly internal_ops_to_run = _sync_variables_ops() internal_ops_to_run.append( _increase_eval_step_op(iterations_per_loop_var)) with ops.control_dependencies(internal_ops_to_run): dummy_update_op = control_flow_ops.no_op() host_call_ret = host_calls.create_tpu_hostcall() eval_metric_ops = {} eval_update_ops = [] for k, v in host_call_ret.get('eval_metrics', {}).items(): eval_metric_ops[k] = (v[0], dummy_update_op) eval_update_ops.append(v[1]) if 'host_call' not in host_call_ret: host_ops = [] else: host_ops = host_call_ret['host_call'] hooks = [ TPUInfeedOutfeedSessionHook( ctx, enqueue_ops, eval_update_ops + host_ops, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode]), ] + input_hooks if eval_hooks: hooks.extend(eval_hooks) return model_fn_lib.EstimatorSpec( mode, loss=mean_loss, evaluation_hooks=hooks, eval_metric_ops=eval_metric_ops, scaffold=scaffold) # Predict assert mode == model_fn_lib.ModeKeys.PREDICT (dummy_predict_op, host_calls, scaffold, prediction_hooks) = _predict_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) with ops.control_dependencies([dummy_predict_op]): internal_ops_to_run = _sync_variables_ops() with ops.control_dependencies(internal_ops_to_run): dummy_predict_op = control_flow_ops.no_op() # In train and evaluation, the main TPU program is passed to monitored # training session to run. Infeed enqueue and outfeed dequeue are # executed in side threads. This is not the configuration for # prediction mode. # # For prediction, the Estimator executes the EstimatorSpec.predictions # directly and yield the element (via generator) to call site. So, the # outfeed based prediction must be passed to MonitoredSession directly. # Other parts of the TPU execution are organized as follows. # # 1. All outfeed based Tensors must be grouped with predictions Tensors # to form a single invocation. This avoid the issue we might trigger # multiple outfeeds incorrectly. To achieve this, `host_call` is # placed in control_dependencies of `stopping_signals`, and # `stopping_signals` is passed into _StoppingPredictHook, which sets # the `stopping_signals` as SessionRunArgs. MonitoredSession merges # all SessionRunArgs with the fetch in session.run together. # # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue) # are grouped together. They will be launched once and only once in # side threads and they quit naturally according to the SAME stopping # condition. enqueue_ops.append(dummy_predict_op) host_call_ret = host_calls.create_tpu_hostcall() if 'host_call' not in host_call_ret: host_ops = [] else: host_ops = host_call_ret['host_call'] predictions = host_call_ret['predictions'] _verify_cross_hosts_transfer_size( predictions, message=( 'The estimated size for TPUEstimatorSpec.predictions is too ' 'large.')) signals = host_call_ret['signals'] with ops.control_dependencies(host_ops): host_ops = [] # Empty, we do do not need it anymore. scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal( signals) predictions = _PaddingSignals.slice_tensor_or_dict( predictions, signals) hooks = [ _StoppingPredictHook(scalar_stopping_signal), TPUInfeedOutfeedSessionHookForPrediction( ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), ] + input_hooks if prediction_hooks: hooks.extend(prediction_hooks) return model_fn_lib.EstimatorSpec( mode, prediction_hooks=hooks, predictions=predictions, scaffold=scaffold) return _model_fn def _is_tpu_tensor(tensor): if not isinstance(tensor, ops.Tensor): return False try: tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access except ValueError: return True else: return False def _export_output_to_tensors(export_output): """Get a list of `Tensors` used in `export_output`. Args: export_output: an `ExportOutput` object such as `ClassificationOutput`, `RegressionOutput`, or `PredictOutput`. Returns: a list of tensors used in export_output. Raises: ValueError: if `export_output` is not one of `ClassificationOutput`, `RegressionOutput`, or `PredictOutput`. """ if isinstance(export_output, export_output_lib.ClassificationOutput): return [export_output.scores, export_output.classes] elif isinstance(export_output, export_output_lib.RegressionOutput): return [export_output.value] elif isinstance(export_output, export_output_lib.PredictOutput): return export_output.outputs.values() else: raise ValueError( '`export_output` must be have type `ClassificationOutput`, ' '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) def _clone_export_output_with_tensors(export_output, tensors): """Clones `export_output` but with new `tensors`. Args: export_output: an `ExportOutput` object such as `ClassificationOutput`, `RegressionOutput`, or `PredictOutput`. tensors: a list of `Tensors` used to construct a new `export_output`. Returns: A dict similar to `export_output` but with `tensors`. Raises: ValueError: if `export_output` is not one of `ClassificationOutput`, `RegressionOutput`, or `PredictOutput`. """ if isinstance(export_output, export_output_lib.ClassificationOutput): if len(tensors) != 2: raise ValueError('tensors must be of length 2; ' 'got {}.'.format(len(tensors))) return export_output_lib.ClassificationOutput(*tensors) elif isinstance(export_output, export_output_lib.RegressionOutput): if len(tensors) != 1: raise ValueError('tensors must be of length 1; ' 'got {}'.format(len(tensors))) return export_output_lib.RegressionOutput(*tensors) elif isinstance(export_output, export_output_lib.PredictOutput): return export_output_lib.PredictOutput( dict(zip(export_output.outputs.keys(), tensors))) else: raise ValueError( '`export_output` must be have type `ClassificationOutput`, ' '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn) def multi_tpu_eval_steps_on_single_shard(): loop_vars = [_ZERO_LOSS] if model_fn_wrapper._eval_cache_fn is not None: batch_size = ctx.global_batch_size num_shards = ctx._config._tpu_config.num_shards loop_vars += model_fn_wrapper._eval_cache_fn(batch_size // num_shards) return training_loop.repeat( iterations_per_loop_var, single_tpu_eval_step, loop_vars) ret = tpu.shard( multi_tpu_eval_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) loss = ret[0] scaffold = _get_scaffold(captured_scaffold_fn) return loss, host_calls, scaffold, captured_eval_hooks.get() def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() (single_tpu_train_step, host_call, captured_scaffold_fn, captured_training_hooks) = ( model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) def multi_tpu_train_steps_on_single_shard(): if model_fn_wrapper._params.get('track_mean', False): loop_vars = [_ZERO_LOSS] else: loop_vars = [_INITIAL_LOSS] if model_fn_wrapper._train_cache_fn is not None: batch_size = ctx.global_batch_size num_shards = ctx._config._tpu_config.num_shards loop_vars += model_fn_wrapper._train_cache_fn(batch_size // num_shards) return training_loop.repeat( iterations_per_loop_var, single_tpu_train_step, loop_vars) ret = tpu.shard( multi_tpu_train_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) loss = ret[0] scaffold = _get_scaffold(captured_scaffold_fn) return loss, host_call, scaffold, captured_training_hooks.get() def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" (single_tpu_predict_step, host_calls, captured_scaffold_fn, captured_predict_hooks ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) def multi_tpu_predict_steps_on_single_shard(): def cond(scalar_stopping_signal): return math_ops.logical_not( _StopSignals.should_stop(scalar_stopping_signal)) inputs = [_StopSignals.NON_STOPPING_SIGNAL] outputs = training_loop.while_loop( cond, single_tpu_predict_step, inputs=inputs, name=b'loop') return outputs (dummy_predict_op,) = tpu.shard( multi_tpu_predict_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() def _wrap_computation_in_while_loop(device, op_fn): """Wraps the ops generated by `op_fn` in tf.while_loop.""" def computation(i): with ops.control_dependencies(op_fn()): return i + 1 iterations_per_loop_var = _create_or_get_iterations_per_loop() # By setting parallel_iterations=1, the parallel execution in while_loop is # basically turned off. with ops.DEVICE(device): iterations = array_ops.identity(iterations_per_loop_var) return control_flow_ops.while_loop( lambda i: i < iterations, computation, [constant_op.constant(0)], parallel_iterations=1) def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn): """Wraps the ops generated by `op_fn` in tf.while_loop.""" def cond(scalar_stopping_signal): return math_ops.logical_not( _StopSignals.should_stop(scalar_stopping_signal)) def computation(unused_scalar_stopping_signal): return_value = op_fn() execute_ops = return_value['ops'] signals = return_value['signals'] with ops.control_dependencies(execute_ops): return _StopSignals.as_scalar_stopping_signal(signals) # By setting parallel_iterations=1, the parallel execution in while_loop is # basically turned off. with ops.DEVICE(device): return control_flow_ops.while_loop( cond, computation, [_StopSignals.NON_STOPPING_SIGNAL], parallel_iterations=1) def _validate_tpu_training_graph(): """Validate graph before running distributed training. Raises: ValueError: If the graph seems invalid for running on DEVICE """ operations = ops.get_default_graph().get_operations() # Check if there is atleast one CrossReplicaSum operation in the graph # This should be introduced by using the CrossShardOptimizer wrapper cross_replica_sum_ops = [ o for o in operations if o.type == _CROSS_REPLICA_SUM_OP ] if not cross_replica_sum_ops: raise ValueError( 'CrossShardOptimizer must be used for model training on TPUs.') class _CapturedObject(object): """A placeholder to capture an object. This is useful when we need to capture a Python object in the Tensorflow control flow body function and use it outside the control flow. """ def __init__(self): self._object = None self._captured = False def capture(self, o): if self._captured: raise RuntimeError( 'InternalError: Object can capture only once. Please file bug.') self._captured = True self._object = o def get(self): if not self._captured: raise RuntimeError( 'InternalError: Object is not captured properly before `get`. ' 'Please file bug.') return self._object def _get_scaffold(captured_scaffold_fn): """Retrieves the Scaffold from `captured_scaffold_fn`.""" with _CapturingContext(message='Inside scaffold_fn'): scaffold_fn = captured_scaffold_fn.get() if scaffold_fn: scaffold = scaffold_fn() if scaffold is None: raise ValueError( 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') else: scaffold = None if scaffold: wrapped_finalize = scaffold.finalize def _finalize(): with _CapturingContext('Inside Scaffold.finalize'): wrapped_finalize() scaffold.finalize = _finalize return scaffold class _CapturingContext(control_flow_ops.ControlFlowContext): """Tracks references to Tensors defined in TPU replication.""" def __init__(self, message): control_flow_ops.ControlFlowContext.__init__(self) self._message = message def AddOp(self, op): # pylint: disable=invalid-name for c in op.inputs: if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access raise ValueError('{}: Op {} depends on TPU computation {}, ' 'which is not allowed.'.format(self._message, op, c)) def to_control_flow_context_def(self, context_def, export_scope=None): # pylint: disable=useless-super-delegation # NOTE(slebedev): the method is required by `ControlFlowContext`. super(_CapturingContext, self).to_control_flow_context_def( context_def, export_scope) def __enter__(self): # pylint: disable=protected-access self._g = ops.get_default_graph() self._old = self._g._get_control_flow_context() self._g._set_control_flow_context(self) # pylint: enable=protected-access def __exit__(self, _, __, ___): # pylint: disable=invalid-name self._g._set_control_flow_context(self._old) # pylint: disable=protected-access class _Inputs(object): """A data structure representing the input_fn returned values. This also supports the returned value from input_fn as `Dataset`. """ def __init__(self, features=None, labels=None, dataset=None, signals=None): if dataset is not None and (features is not None or labels is not None or signals is not None): raise RuntimeError('Internal Error: Either (features and labels) or ' 'dataset should be provided, not both. Please file ' 'bug') self._features = features self._labels = labels self._signals = signals self._dataset = dataset self._iterator = None @staticmethod def from_input_fn(return_values): """Returns an `_Inputs` instance according to `input_fn` return value.""" if isinstance(return_values, dataset_ops.Dataset): dataset = return_values return _Inputs(dataset=dataset) features, labels = _Inputs._parse_inputs(return_values) return _Inputs(features, labels) @staticmethod def _parse_inputs(return_values): if isinstance(return_values, tuple): features, labels = return_values else: features, labels = return_values, None return features, labels @property def is_dataset(self): """Returns True if the return value from input_fn is Dataset.""" return self._dataset is not None def dataset_initializer_hook(self): """Returns a `SessionRunHook` to initialize this dataset. This must be called before `features_and_labels`. """ iterator = self._dataset.make_initializable_iterator() # pylint: disable=protected-access hook = estimator_util._DatasetInitializerHook(iterator) # pylint: enable=protected-access self._iterator = iterator return hook def features_and_labels(self): """Gets `features` and `labels`.""" if self.is_dataset: if self._iterator is None: raise RuntimeError('Internal error: Must call dataset_initializer_hook ' 'before calling features_and_labels(). Please file ' 'a bug!') return _Inputs._parse_inputs(self._iterator.get_next()) return (self._features, self._labels) def signals(self): return self._signals @property def dataset(self): return self._dataset class _InputsWithStoppingSignals(_Inputs): """Inputs with `_StopSignals` inserted into the dataset.""" def __init__(self, dataset, batch_size, add_padding=False, num_invocations_per_step=1): assert dataset is not None user_provided_dataset = dataset.map( _InputsWithStoppingSignals.insert_stopping_signal( stop=False, batch_size=batch_size, add_padding=add_padding)) if num_invocations_per_step == 1: final_batch_dataset = dataset.take(1).map( _InputsWithStoppingSignals.insert_stopping_signal( stop=True, batch_size=batch_size, add_padding=add_padding)) else: # We append (2 * num_invocations_per_step - 1) batches for exhausting the # user_provided_dataset and stop properly. # For example, if num_invocations_per_step is 2, we append 3 additional # padding batches: b1, b2, b3. # If user_provided_dataset contains two batches: a1, a2 # Step 1: [a1, a2] # Step 2: [b1, b2] -> STOP # If user_provided_dataset contains three batches: a1, a2, a3. # The training loops: # Step 1: [a1, a2] # Step 2: [a3, b1] # Step 3: [b2, b3] -> STOP. final_batch_dataset = dataset.take(1).map( _InputsWithStoppingSignals.insert_stopping_signal( stop=True, batch_size=batch_size, add_padding=add_padding)) final_batch_dataset = final_batch_dataset.repeat( 2 * num_invocations_per_step - 1) def _set_mask(data_dict): signals = data_dict['signals'] signals['padding_mask'] = array_ops.ones_like(signals['padding_mask']) data_dict['signals'] = signals return data_dict # Mask out the extra batch. final_batch_dataset = final_batch_dataset.map(_set_mask) dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2) super(_InputsWithStoppingSignals, self).__init__(dataset=dataset) self._current_inputs = None def features_and_labels(self): if self._current_inputs is not None: raise RuntimeError( 'Internal Error: The previous inputs have not been properly ' 'consumed. First call features_and_labels, then call signals.') inputs_with_signals = self._iterator.get_next() features = inputs_with_signals['features'] labels = inputs_with_signals.get('labels') self._current_inputs = inputs_with_signals return features, labels def signals(self): """Returns the `Signals` from `_Inputs`.""" if self._current_inputs is None: raise RuntimeError( 'Internal Error: The current inputs have not been properly ' 'generated. First call features_and_labels, then call signals.') signals = self._current_inputs['signals'] self._current_inputs = None return signals @staticmethod def insert_stopping_signal(stop, batch_size, add_padding=False): """Inserts stopping_signal into dataset via _map_fn. Here we change the data structure in the dataset, such that the return value is a dictionary now and `features`, `labels`, and `signals` are three distinguished keys in that dict. This provides a better structure, which eases the process to decompose the inputs (see `features_and_labels`). Args: stop: bool, state of current stopping signals. batch_size: int, batch size. add_padding: bool, whether to pad the tensor to full batch size. Returns: A map_fn passed to dataset.map API. """ def _map_fn(*args): """The map fn to insert signals.""" if len(args) == 1: # Unpack the single Tensor/dict argument as features. This is required # for the input_fn returns no labels. args = args[0] features, labels = _Inputs._parse_inputs(args) new_input_dict = {} if add_padding: padding_mask, features, labels = ( _PaddingSignals.pad_features_and_labels( features, labels, batch_size)) new_input_dict['features'] = features if labels is not None: new_input_dict['labels'] = labels else: new_input_dict['features'] = features if labels is not None: new_input_dict['labels'] = labels padding_mask = None new_input_dict['signals'] = _StopSignals( stop=stop, batch_size=batch_size, padding_mask=padding_mask).as_dict() return new_input_dict return _map_fn class _StopSignals(object): """Signals class holding all logic to handle TPU stopping condition.""" NON_STOPPING_SIGNAL = False STOPPING_SIGNAL = True def __init__(self, stop, batch_size, padding_mask=None): self._stop = stop self._batch_size = batch_size self._padding_mask = padding_mask def as_dict(self): """Returns the signals as Python dict.""" shape = [self._batch_size, 1] dtype = dtypes.bool if self._stop: stopping = array_ops.ones(shape=shape, dtype=dtype) else: stopping = array_ops.zeros(shape=shape, dtype=dtype) signals = {'stopping': stopping} if self._padding_mask is not None: signals['padding_mask'] = self._padding_mask return signals @staticmethod def as_scalar_stopping_signal(signals): return array_ops.identity(signals['stopping'][0][0]) @staticmethod def should_stop(scalar_stopping_signal): """Detects whether scalar_stopping_signal indicates stopping.""" if isinstance(scalar_stopping_signal, ops.Tensor): # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF # way to express the bool check whether scalar_stopping_signal is True. return math_ops.logical_and( scalar_stopping_signal, _StopSignals.STOPPING_SIGNAL) else: # For non Tensor case, it is used in SessionRunHook. So, we cannot modify # the graph anymore. Here, we use pure Python. return bool(scalar_stopping_signal) class _PaddingSignals(object): """Signals class holding all logic to handle padding.""" @staticmethod def pad_features_and_labels(features, labels, batch_size): """Pads out the batch dimension of features and labels.""" real_batch_size = array_ops.shape( _PaddingSignals._find_any_tensor(features))[0] batch_size_tensor = constant_op.constant(batch_size, dtypes.int32) check_greater = check_ops.assert_greater_equal( batch_size_tensor, real_batch_size, data=(batch_size_tensor, real_batch_size), message='The real batch size should not be greater than batch_size.') with ops.control_dependencies([check_greater]): missing_count = batch_size_tensor - real_batch_size def pad_single_tensor(tensor): """Pads out the batch dimension of a tensor to the complete batch_size.""" rank = len(tensor.shape) assert rank > 0 padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1)) padded_shape = (batch_size,) + tuple(tensor.shape[1:]) padded_tensor = array_ops.pad(tensor, padding) padded_tensor.set_shape(padded_shape) return padded_tensor def nest_pad(tensor_or_dict): return nest.map_structure(pad_single_tensor, tensor_or_dict) features = nest_pad(features) if labels is not None: labels = nest_pad(labels) padding_mask = _PaddingSignals._padding_mask( real_batch_size, missing_count, batch_size) return padding_mask, features, labels @staticmethod def slice_tensor_or_dict(tensor_or_dict, signals): """Slice the real Tensors according to padding mask in signals.""" padding_mask = signals['padding_mask'] batch_size = array_ops.shape(padding_mask)[0] def verify_batch_size(tensor): check_batch_size = math_ops.equal(batch_size, tensor.shape[0]) with ops.control_dependencies([check_batch_size]): return array_ops.identity(tensor) def slice_single_tensor(tensor): rank = len(tensor.shape) assert rank > 0 real_batch_size = batch_size - math_ops.reduce_sum(padding_mask) return verify_batch_size(tensor)[0:real_batch_size] # As we split the Tensors to all TPU cores and concat them back, it is # important to ensure the real data is placed before padded ones, i.e., # order is preserved. By that, the sliced padding mask should have all 0's. # If this assertion failed, # the slice logic here would not hold. sliced_padding_mask = slice_single_tensor(padding_mask) assert_padding_mask = math_ops.equal( math_ops.reduce_sum(sliced_padding_mask), 0) with ops.control_dependencies([assert_padding_mask]): should_stop = _StopSignals.should_stop( _StopSignals.as_scalar_stopping_signal(signals)) is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0) def slice_fn(tensor): # If the current batch is full batch or part of stopping signals, we do # not need to slice to save performance. return control_flow_ops.cond( math_ops.logical_or(should_stop, is_full_batch), (lambda: verify_batch_size(tensor)), (lambda: slice_single_tensor(tensor))) return nest.map_structure(slice_fn, tensor_or_dict) @staticmethod def _find_any_tensor(batch_features): tensors = [x for x in nest.flatten(batch_features) if isinstance(x, ops.Tensor)] if not tensors: raise ValueError('Cannot find any Tensor in features dict.') return tensors[0] @staticmethod def _padding_mask(real_batch_size, missing_count, batch_size): padding_mask = array_ops.concat( [ array_ops.zeros((real_batch_size,), dtype=dtypes.int32), array_ops.ones((missing_count,), dtype=dtypes.int32) ], axis=0) padding_mask.set_shape((batch_size,)) return padding_mask def _verify_cross_hosts_transfer_size(tensor_dict, message): total_size = 0 tensor_structure = {} for key, tensor in tensor_dict.items(): shape = tensor.shape size = np.product(shape) * tensor.dtype.size tensor_structure[key] = shape total_size += size if total_size >= _ONE_GIGABYTE: raise ValueError( '{} The transfer size is larger than the protobuf limit. Please ' 'consider to use Tensors with smaller shapes or reduce batch ' 'size. Given:\n' '{}'.format(message, '\n'.join([ ' -- Key: {}, Shape: {}'.format(k, v) for k, v in tensor_structure.items()]))) def _add_item_to_params(params, key, value): """Adds a new item into `params`.""" if isinstance(params, hparam.HParams): # For HParams, we need to use special API. if key in params: params.set_hparam(key, value) else: params.add_hparam(key, value) else: # Now params is Python dict. params[key] = value def export_estimator_savedmodel(estimator, export_dir_base, serving_input_receiver_fn, assets_extra=None, as_text=False, checkpoint_path=None, strip_default_attrs=False): """Export `Estimator` trained model for TPU inference. Args: estimator: `Estimator` with which model has been trained. export_dir_base: A string containing a directory in which to create timestamped subdirectories containing exported SavedModels. serving_input_receiver_fn: A function that takes no argument and returns a `ServingInputReceiver` or `TensorServingInputReceiver`. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel, or `None` if no extra assets are needed. as_text: whether to write the SavedModel proto in text format. checkpoint_path: The checkpoint path to export. If `None` (the default), the most recent checkpoint found within the model directory is chosen. strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. Returns: The string path to the exported directory. """ # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use # `estimator.config`. config = tpu_config.RunConfig(model_dir=estimator.model_dir) est = TPUEstimator( estimator._model_fn, # pylint: disable=protected-access config=config, params=estimator.params, use_tpu=True, train_batch_size=2048, # Does not matter. eval_batch_size=2048, # Does not matter. ) return est.export_savedmodel(export_dir_base, serving_input_receiver_fn, assets_extra, as_text, checkpoint_path, strip_default_attrs)