This repository was archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../transformer-xl/tf/tpu_estimator.py
2025-11-25 20:20:08 +01:00

3519 lines
137 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPUEstimator class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import os
import signal
import sys
import threading
import time
import numpy as np
import six
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from six.moves import xrange # pylint: disable=redefined-builtin
import math
try:
import google3
from google3.third_party.tensorflow.contrib.tpu.python.ops import tpu_ops
from google3.third_party.tensorflow.contrib.tpu.python.tpu import error_handling
from google3.third_party.tensorflow.contrib.tpu.python.tpu import session_support
from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu
from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_config
from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_context
from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_feed
from google3.third_party.tensorflow.contrib.tpu.python.tpu import training_loop
from google3.third_party.tensorflow.contrib.tpu.python.tpu import util as util_lib
from google3.third_party.tensorflow.contrib.training.python.training import hparam
from google3.third_party.tensorflow.core.framework import variable_pb2
from google3.third_party.tensorflow.core.framework.summary_pb2 import Summary
from google3.third_party.tensorflow.core.protobuf import config_pb2
from google3.third_party.tensorflow.python.data.ops import dataset_ops
from google3.third_party.tensorflow.python.data.util import nest as data_nest
from google3.third_party.tensorflow.python.estimator import estimator as estimator_lib
from google3.third_party.tensorflow.python.estimator import model_fn as model_fn_lib
from google3.third_party.tensorflow.python.estimator.export import export_output as export_output_lib
from google3.third_party.tensorflow.python.framework import constant_op
from google3.third_party.tensorflow.python.framework import dtypes
from google3.third_party.tensorflow.python.framework import errors
from google3.third_party.tensorflow.python.framework import ops
from google3.third_party.tensorflow.python.ops import array_ops
from google3.third_party.tensorflow.python.ops import check_ops
from google3.third_party.tensorflow.python.ops import control_flow_ops
from google3.third_party.tensorflow.python.ops import init_ops
from google3.third_party.tensorflow.python.ops import math_ops
from google3.third_party.tensorflow.python.ops import resource_variable_ops
from google3.third_party.tensorflow.python.ops import state_ops
from google3.third_party.tensorflow.python.ops import summary_ops_v2 as contrib_summary
from google3.third_party.tensorflow.python.ops import variable_scope
from google3.third_party.tensorflow.python.ops import variables
from google3.third_party.tensorflow.python.platform import tf_logging as logging
from google3.third_party.tensorflow.python.saved_model import tag_constants
from google3.third_party.tensorflow.python.summary import summary
from google3.third_party.tensorflow.python.training import basic_session_run_hooks
from google3.third_party.tensorflow.python.training import evaluation
from google3.third_party.tensorflow.python.training import session_run_hook
from google3.third_party.tensorflow.python.training import training
from google3.third_party.tensorflow.python.training import training_util
from google3.third_party.tensorflow.python.util import function_utils
from google3.third_party.tensorflow.python.util import nest
from google3.third_party.tensorflow.python.util import tf_inspect
except:
import tensorflow
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import error_handling
from tensorflow.contrib.tpu.python.tpu import session_support
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_context
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.contrib.training.python.training import hparam
from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest as data_nest
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2 as contrib_summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
# Loss constants. NOTE(review): their consumers are outside this view;
# presumably a large starting loss and a zero loss -- confirm against callers.
_INITIAL_LOSS = 1e7
_ZERO_LOSS = 0.
# Used in this file as the variable scope of the iterations_per_loop variable
# and as the prefix of the collection name that holds it.
_TPU_ESTIMATOR = 'custom_tpu_estimator' # CHANGE FOR RECURRENCY
# Variable name (and collection-name suffix) of the iterations_per_loop
# variable.
_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
# String keys; `_BATCH_SIZE_KEY` and `_CTX_KEY` form the reserved params keys
# listed in `_RESERVED_PARAMS_KEYS` below.
_BATCH_SIZE_KEY = 'batch_size'
_CTX_KEY = 'context'
_USE_TPU_KEY = 'use_tpu'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
# Number of bytes in one gibibyte (1024 ** 3).
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
# Ideally _USE_TPU_KEY should be reserved as well. However there are already
# models that make use of this key, thus it can not be reserved now to prevent
# breakage. In the long run, we would like to mitigate this by migrating models
# off of using _USE_TPU_KEY.
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]
# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is
# only used for per-core based deployments. For per-host based pipelines, if a
# user returns a Dataset instance it will be automatically wrapped in a
# tf.while_loop (This can be disabled by returning features and labels
# explicitly).
_WRAP_INPUT_FN_INTO_WHILE_LOOP = False
# Register the proto (de)serialization functions for the collection that stores
# the iterations_per_loop resource variable. The collection name is built the
# same way as in `_create_or_get_iterations_per_loop`.
ops.register_proto_function(
'{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),
proto_type=variable_pb2.VariableDef,
to_proto=resource_variable_ops._to_proto_fn, # pylint: disable=protected-access
from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access
def _create_global_step(graph):
  """Creates the int64 resource variable that serves as the global step.

  Args:
    graph: Graph to create the variable in. The default graph is used when
      this is falsy.

  Returns:
    The newly created, non-trainable global step variable.

  Raises:
    ValueError: If the graph already has a global step.
  """
  target_graph = graph if graph else ops.get_default_graph()
  existing_step = training.get_global_step(target_graph)
  if existing_step is not None:
    raise ValueError('"global_step" already exists.')
  step_collections = [
      ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
  ]
  # Build the variable inside the target graph, at the root name scope.
  with target_graph.as_default() as active_graph:
    with active_graph.name_scope(None):
      return variable_scope.get_variable(
          ops.GraphKeys.GLOBAL_STEP,
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          use_resource=True,
          collections=step_collections)
def _create_or_get_iterations_per_loop():
  """Returns the iterations_per_loop variable, creating it on first use.

  TPUEstimator wraps the user's model_fn in a tf.while_loop for peak
  performance. This variable holds the number of loop iterations; its value is
  adjusted on the CPU after each TPU program execution and before the next one.

  A variable is used instead of a constant so that TPUEstimator can adapt the
  number of TPU training iterations to the final step count requested by the
  user. For example, with iterations_per_loop=4 in TPUConfig and steps=10 in
  TPUEstimator.train(), the variable holds the following values:

    - 1st TPU execution: iterations_per_loop = 4
    - 2nd TPU execution: iterations_per_loop = 4
    - 3rd TPU execution: iterations_per_loop = 2

  Since model_fn increases the global step once per train_op invocation, the
  global step is 10 after all TPU executions, matching steps=10.

  Returns:
    A non-trainable int32 resource variable.

  Raises:
    RuntimeError: If more than one iterations_per_loop variable is found in
      the collection.
  """
  collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR)
  existing_vars = ops.get_default_graph().get_collection(collection_name)
  if len(existing_vars) > 1:
    raise RuntimeError('Multiple iterations_per_loop_var in collection.')
  if len(existing_vars) == 1:
    return existing_vars[0]

  # Nothing registered yet: create the variable next to the global step.
  with ops.colocate_with(training_util.get_global_step()):
    scope = variable_scope.variable_scope(
        _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE)
    with scope:
      return variable_scope.get_variable(
          _ITERATIONS_PER_LOOP_VAR,
          initializer=init_ops.zeros_initializer(),
          shape=[],
          dtype=dtypes.int32,
          trainable=False,
          collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
          use_resource=True)
def _sync_variables_ops():
  """Builds ops that read back every trainable variable with a numerics check.

  Reading the variables gets them back from the TPU nodes, i.e. the values
  updated by the TPU are *synced* to host memory.

  Returns:
    A list with one `check_numerics` op per trainable variable.
  """
  sync_ops = []
  for var in variables.trainable_variables():
    checked_value = array_ops.check_numerics(
        var.read_value(), 'Gradient for %s is NaN' % var.name)
    sync_ops.append(checked_value.op)
  return sync_ops
def _increase_eval_step_op(iterations_per_loop):
  """Builds the op that advances the eval step for a TPU evaluation loop.

  Args:
    iterations_per_loop: Tensor. Number of eval steps executed on the TPU
      system before control returns to the CPU host for each `Session.run`.

  Returns:
    An operation that adds `iterations_per_loop - 1` to the eval step.
  """
  eval_step = evaluation._get_or_create_eval_step()  # pylint: disable=protected-access
  # Estimator's evaluate already adds 1 per run, so only add the remainder.
  extra_steps = math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype)
  return state_ops.assign_add(eval_step, extra_steps, use_locking=True)
def _extract_key_names(tensor_or_dict):
if isinstance(tensor_or_dict, dict):
return sorted(tensor_or_dict.keys())
return []
class _SIGNAL(object):
"""Signal used to control the thread of infeed/outfeed.
All preserved signals must be negative numbers. Positive numbers are used to
indicate the number of iterations for next training/evaluation loop.
"""
NEXT_BATCH = -1
STOP = -2
class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
  """Ops and objects a `model_fn` returns to `TPUEstimator`.

  `mode`, `predictions`, `loss`, `train_op` and `export_outputs` have the same
  meaning as in `EstimatorSpec`.

  Evaluation uses `eval_metrics`, a tuple `(metric_fn, tensors)`. This differs
  from `EstimatorSpec.eval_metric_ops`, which expects a dict. `metric_fn` runs
  on the CPU and produces the metrics; `tensors` are the `Tensor`s transferred
  from the TPU system to the CPU host and handed to `metric_fn`. `tensors` may
  be a list of `Tensor`s or a dict from names to `Tensor`s, and usually holds
  the model logits. All tensors must be batch-major, i.e. the batch size is the
  first dimension. Once the tensors of all shards are available on the CPU
  host they are concatenated (on CPU) and passed to `metric_fn` as positional
  arguments if `tensors` is a list, or as keyword arguments if it is a dict.
  `metric_fn` returns a dict from metric name to the result of calling a
  metric function, namely a `(metric_tensor, update_op)` tuple. See
  `TPUEstimator` for an MNIST example of how to specify `eval_metrics`.

  `scaffold_fn` is a function that runs on the CPU to generate the `Scaffold`.
  It should not capture any Tensors from `model_fn`.

  `host_call` is a tuple of a `function` and a list or dict of `tensors` to
  pass to that function; the function returns a list of Tensors. `host_call`
  currently works for train() and evaluate(). The function is executed on the
  CPU on every step, so sending tensors from TPU to CPU adds communication
  overhead; keep the tensors small to reduce it. The `tensors` are
  concatenated along their major (batch) dimension and therefore must have
  rank >= 1. `host_call` is useful for writing summaries with
  `tf.contrib.summary.create_file_writer`.
  """

  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metrics=None,
              export_outputs=None,
              scaffold_fn=None,
              host_call=None,
              training_hooks=None,
              evaluation_hooks=None,
              prediction_hooks=None):
    """Creates a validated `TPUEstimatorSpec` instance.

    Raises:
      TypeError: If any supplied hook is not a `SessionRunHook`.
    """
    # Only the host calls that were actually supplied get validated.
    candidates = (('eval_metrics', eval_metrics), ('host_call', host_call))
    host_calls = {
        key: value for key, value in candidates if value is not None
    }
    _OutfeedHostCall.validate(host_calls)

    training_hooks = list(training_hooks or [])
    evaluation_hooks = list(evaluation_hooks or [])
    prediction_hooks = list(prediction_hooks or [])
    for hook_group in (training_hooks, evaluation_hooks, prediction_hooks):
      for hook in hook_group:
        if isinstance(hook, session_run_hook.SessionRunHook):
          continue
        raise TypeError(
            'All hooks must be SessionRunHook instances, given: {}'.format(
                hook))

    return super(TPUEstimatorSpec, cls).__new__(
        cls,
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
        training_hooks=training_hooks,
        evaluation_hooks=evaluation_hooks,
        prediction_hooks=prediction_hooks)

  def as_estimator_spec(self):
    """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
    candidates = (('eval_metrics', self.eval_metrics),
                  ('host_call', self.host_call))
    host_calls = {
        key: value for key, value in candidates if value is not None
    }
    host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)

    if self.eval_metrics is None:
      eval_metric_ops = None
    else:
      eval_metric_ops = host_call_ret['eval_metrics']

    # The host call, when present, is run through one extra hook that is
    # appended to the hooks of every mode.
    if self.host_call is None:
      extra_hooks = []
    else:
      extra_hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]

    scaffold = None
    if self.scaffold_fn:
      scaffold = self.scaffold_fn()

    return model_fn_lib.EstimatorSpec(
        mode=self.mode,
        predictions=self.predictions,
        loss=self.loss,
        train_op=self.train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=self.export_outputs,
        scaffold=scaffold,
        training_hooks=self.training_hooks + extra_hooks,
        evaluation_hooks=self.evaluation_hooks + extra_hooks,
        prediction_hooks=self.prediction_hooks + extra_hooks)
class _OpQueueContext(object):
"""Manages work queue and thread for a infeed/outfeed thread."""
def __init__(self, name, target, args):
self._name = name
self._queue = Queue.Queue()
args = (self,) + args
self._thread = threading.Thread(name=name, target=target, args=args)
self._thread.daemon = True
self._thread.start()
def stop(self):
self._queue.put(_SIGNAL.STOP)
def send_next_batch_signal(self, iterations):
self._queue.put(iterations)
def read_iteration_counts(self):
while True:
iterations = self._queue.get(block=True)
logging.debug('%s read iterations %s', self._name, iterations)
if iterations == _SIGNAL.STOP:
logging.info('%s received shutdown signal, stopping.', self._name)
return
yield iterations
def join(self):
logging.info('Shutting down %s thread.' % self._name)
self.stop()
self._thread.join()
class _OpSignalOnceQueueContext(_OpQueueContext):
  """Queue/thread manager that forwards only the first next-batch signal.

  Behaves like `_OpQueueContext`, except that every call to
  `send_next_batch_signal` after the first one is ignored.
  """

  def __init__(self, name, target, args):
    super(_OpSignalOnceQueueContext, self).__init__(name, target, args)
    self._has_signaled = False

  def send_next_batch_signal(self, iterations):
    """Enqueues `iterations` on the first call; later calls do nothing."""
    if self._has_signaled:
      return
    self._queue.put(iterations)
    self._has_signaled = True
class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
  """A Session hook setting up the TPU initialization, infeed, and outfeed.

  This hook does two major things:
  1. initialize and shutdown TPU system.
  2. launch and join the threads for infeed enqueue and (optional) outfeed
     dequeue.
  """

  def __init__(self,
               ctx,
               enqueue_ops,
               dequeue_ops,
               run_infeed_loop_on_coordinator=True,
               rendezvous=None):
    """Stores the ops and settings used by the infeed/outfeed threads.

    Args:
      ctx: Context object; `ctx.master_job` and
        `ctx.config.tpu_config.initial_infeed_sleep_secs` are read here.
      enqueue_ops: Ops run with `session.run` by the infeed thread.
      dequeue_ops: Ops run with `session.run` by the outfeed thread.
      run_infeed_loop_on_coordinator: If true, the infeed thread runs
        `enqueue_ops` once per iteration of each signalled loop; otherwise it
        runs them once per signal.
      rendezvous: Object providing `catch_errors` and `record_done`.
        NOTE(review): it is used unconditionally by the threads and by `end`,
        so the `None` default only works if those are never reached.
    """
    self._master_job = ctx.master_job
    self._enqueue_ops = enqueue_ops
    self._dequeue_ops = dequeue_ops
    self._rendezvous = rendezvous

    self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator
    self._initial_infeed_sleep_secs = (
        ctx.config.tpu_config.initial_infeed_sleep_secs)

    self._feed_error = None
    self._finished = False

  def begin(self):
    """Builds the TPU/summary-writer init ops and the matching finalize ops."""
    logging.info('TPU job name %s', self._master_job)
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
    self._init_ops = [tpu.initialize_system(job=self._master_job)]
    self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]

    summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
    self._init_ops.extend(summary_writer_init_ops)
    # Get all the writer resources from the initializer, so we know what to
    # flush.
    for op in summary_writer_init_ops:
      self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))

  def _run_infeed(self, queue_ctx, session):
    """Infeed thread body: runs the enqueue ops for every signalled loop."""
    logging.info('Starting infeed thread controller.')
    if self._initial_infeed_sleep_secs:
      # This hook has no `_name` attribute, so the thread is labelled with a
      # literal here; reading `self._name` raised AttributeError whenever an
      # initial sleep was configured.
      logging.info('Infeed thread sleeping for %d seconds.',
                   self._initial_infeed_sleep_secs)
      time.sleep(self._initial_infeed_sleep_secs)
      logging.info('Infeed thread starting after sleep')

    with self._rendezvous.catch_errors(source='infeed', session=session):
      if self._run_infeed_loop_on_coordinator:
        for count, steps in enumerate(queue_ctx.read_iteration_counts()):
          for i in xrange(steps):
            logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
            session.run(self._enqueue_ops)
      else:
        # One run of the enqueue ops per signal.
        for _ in queue_ctx.read_iteration_counts():
          session.run(self._enqueue_ops)
      logging.info('Infeed thread finished, shutting down.')

  def _run_outfeed(self, queue_ctx, session):
    """Outfeed thread body: runs the dequeue ops once per iteration."""
    logging.info('Starting outfeed thread controller.')
    with self._rendezvous.catch_errors(source='outfeed', session=session):
      for count, steps in enumerate(queue_ctx.read_iteration_counts()):
        for i in xrange(steps):
          logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
          session.run(self._dequeue_ops)
      logging.info('Outfeed thread finished, shutting down.')

  def _create_infeed_controller(self, name, target, args):
    """Creates the infeed queue context; subclasses may override this."""
    return _OpQueueContext(name=name, target=target, args=args)

  def after_create_session(self, session, coord):
    """Runs the init ops and starts the infeed and outfeed threads."""
    logging.info('Init TPU system')
    # The init ops are run with a five minute timeout.
    session.run(self._init_ops,
                options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))

    self._infeed_controller = self._create_infeed_controller(
        name='InfeedController', target=self._run_infeed, args=(session,))

    self._outfeed_controller = _OpQueueContext(
        name='OutfeedController', target=self._run_outfeed, args=(session,))

  def before_run(self, run_context):
    """Reads iterations_per_loop and signals both threads with its value."""
    self._feed_error = None

    iterations = run_context.session.run(self._iterations_per_loop_var)

    logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
    self._infeed_controller.send_next_batch_signal(iterations)

    logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
                 iterations)
    self._outfeed_controller.send_next_batch_signal(iterations)

  def end(self, session):
    """Joins both threads, records them as done and runs the finalize ops."""
    self._finished = True
    logging.info('Stop infeed thread controller')
    self._infeed_controller.join()
    self._rendezvous.record_done('infeed')

    logging.info('Stop output thread controller')
    self._outfeed_controller.join()
    self._rendezvous.record_done('outfeed')

    logging.info('Shutdown TPU system.')
    session.run(self._finalize_ops)
class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook):
  """Infeed/outfeed hook variant used for prediction.

  The infeed loop is not run per iteration on the coordinator, and the infeed
  controller forwards only the first next-batch signal.
  """

  def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None):
    parent = super(TPUInfeedOutfeedSessionHookForPrediction, self)
    parent.__init__(
        ctx,
        enqueue_ops,
        dequeue_ops,
        run_infeed_loop_on_coordinator=False,
        rendezvous=rendezvous)

  def _create_infeed_controller(self, name, target, args):
    """Returns a queue context that signals the infeed thread only once."""
    return _OpSignalOnceQueueContext(name=name, target=target, args=args)
class _TPUStopAtStepHook(session_run_hook.SessionRunHook):
  """Requests a stop once the global step reaches a target step.

  Compared with `session_run_hook._StopAfterNEvalsHook`, TPU training needs two
  extra things:

  1. The hook also sets the iterations_per_loop variable that
     `TPUInfeedOutfeedSessionHook` uses to drive infeed/outfeed. Hook execution
     order is not guaranteed and `TPUInfeedOutfeedSessionHook` reads the
     variable in `before_run`, so the update happens in `after_create_session`
     and `after_run`.

  2. A single training loop (session.run) may increase the global step several
     times on the TPU. The global step is therefore read again explicitly in
     `after_run` so that the latest value is used and a race condition is
     avoided.
  """

  def __init__(self, iterations, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    Args:
      iterations: The number of iterations to run optimizer per training loop.
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If neither or both of `num_steps` and `last_step` are given.
    """
    if num_steps is None:
      if last_step is None:
        raise ValueError('One of num_steps or last_step must be specified.')
    elif last_step is not None:
      raise ValueError('Only one of num_steps or last_step can be specified.')
    self._num_steps = num_steps
    self._last_step = last_step
    self._iterations = iterations

  def _next_iterations(self, global_step, last_step):
    """Returns the next loop size, capped by the steps still remaining."""
    return min(last_step - global_step, self._iterations)

  def begin(self):
    """Looks up the global step and the iterations_per_loop variable."""
    step_tensor = training_util.get_global_step()
    self._global_step_tensor = step_tensor
    if step_tensor is None:
      raise RuntimeError('Global step should be created.')
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    """Resolves the final step and loads the first loop size."""
    current_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = current_step + self._num_steps
    self._iterations_per_loop_var.load(
        self._next_iterations(current_step, self._last_step), session=session)

  def after_run(self, run_context, run_values):
    """Stops at the final step, otherwise loads the next loop size."""
    # Global step cannot be retrieved via SessionRunArgs and before_run due to
    # race condition.
    current_step = run_context.session.run(self._global_step_tensor)
    if current_step < self._last_step:
      self._iterations_per_loop_var.load(
          self._next_iterations(current_step, self._last_step),
          session=run_context.session)
    else:
      run_context.request_stop()
class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
  """Hook that loads a fixed step count into the iterations_per_loop variable."""

  def __init__(self, num_steps):
    """Initializes a `_SetEvalIterationsHook`.

    Args:
      num_steps: Number of steps to execute.
    """
    self._num_steps = num_steps

  def begin(self):
    """Creates or fetches the iterations_per_loop variable."""
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    """Loads `num_steps` into the variable once the session exists."""
    loop_var = self._iterations_per_loop_var
    loop_var.load(self._num_steps, session=session)
class _StoppingPredictHook(session_run_hook.SessionRunHook):
  """Stops prediction once the scalar stopping signal says so."""

  def __init__(self, scalar_stopping_signal):
    self._scalar_stopping_signal = scalar_stopping_signal

  def begin(self):
    """Creates or fetches the iterations_per_loop variable."""
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    """Loads 1 into the iterations_per_loop variable."""
    # Not strictly needed: in prediction mode infeed enqueue and outfeed
    # dequeue are not run in side threads. It only makes the messages printed
    # by TPUInfeedOutfeedSessionHook look nice.
    self._iterations_per_loop_var.load(1, session=session)

  def before_run(self, run_context):
    """Asks the session to also fetch the stopping signal."""
    return session_run_hook.SessionRunArgs(self._scalar_stopping_signal)

  def after_run(self, run_context, run_values):
    """Raises `OutOfRangeError` when the fetched signal requests a stop."""
    del run_context  # Unused.
    if not _StopSignals.should_stop(run_values.results):
      return
    # NOTE(xiejw): In prediction every batch carries a stopping signal, and
    # one extra batch is appended to tell the system to stop. The data flow
    # looks like
    #
    #   batch   0: images, labels, stop = 0  (user provided)
    #   batch   1: images, labels, stop = 0  (user provided)
    #   ...
    #   batch  99: images, labels, stop = 0  (user provided)
    #   batch 100: images, labels, stop = 1  (TPUEstimator appended)
    #
    # The final batch (id = 100) is appended by TPUEstimator and must be
    # dropped before predictions are returned to the user. Raising
    # OutOfRangeError from after_run achieves that: once Monitored Session
    # sees the error in SessionRunHook.after_run, the "current" prediction,
    # i.e. the batch with id=100, is discarded immediately.
    raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')
def generate_per_core_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, host_device, host_id):
  """Builds the per-core infeed enqueue function for a single host.

  Args:
    ctx: Internal TPU context; supplies the core count per host and the TPU
      ordinal function.
    input_fn: User input function, called once per core with a `TPUContext`.
    inputs_structure_recorder: Validates, records and flattens the structure
      of features and labels.
    host_device: Device passed to the `TPUContext` as the input device.
    host_id: Index of the host the ops are generated for.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue)`. The infeed queue is
    captured when `enqueue_ops_fn` is called.
  """
  captured_infeed_queue = _CapturedObject()
  tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """Returns the enqueue ops covering every core of this host."""
    core_count = ctx.num_of_cores_per_host
    sharded_inputs = []
    for core_ordinal in range(core_count):
      with ops.name_scope('ordinal_%d' % (core_ordinal)):
        invocation_index = (
            host_id * ctx.num_of_cores_per_host + core_ordinal)
        user_context = tpu_context.TPUContext(
            internal_ctx=ctx,
            input_device=host_device,
            invocation_index=invocation_index)
        inputs = _Inputs.from_input_fn(input_fn(user_context))
        if inputs.is_dataset:
          raise TypeError(
              '`input_fn` returning `Dataset` is not yet supported in '
              'per-Core input pipeline deployment yet. Please set '
              'TPUConfig.per_host_input_for_training to True or return '
              '`features` and `labels` from `input_fn`')
        features, labels = inputs.features_and_labels()

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        sharded_inputs.append(
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels))

    # The tuple size of the queue is taken from the first core's inputs.
    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(sharded_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)

    return infeed_queue.generate_enqueue_ops(
        sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)

  return enqueue_ops_fn, captured_infeed_queue
def generate_per_host_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id):
  """Generates infeed enqueue ops for per-host input_fn on a single host.

  Args:
    ctx: Internal TPU context; supplies the mode, the batch size for the
      input_fn, the replica count per host and the TPU ordinal function.
    input_fn: User input function, called once with a `TPUContext`.
    inputs_structure_recorder: Validates, records and flattens the structure
      of features, labels and signals.
    batch_axis: Passed to the infeed queue as `shard_dimensions`. Must be
      `None` in PREDICT mode.
    device: Device on which the input pipeline and the enqueue ops are placed.
    host_id: Index of the host the ops are generated for.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset)`.

  Raises:
    TypeError: In PREDICT mode, if `input_fn` does not return a `Dataset` or
      if `batch_axis` is not `None`.
  """
  captured_infeed_queue = _CapturedObject()

  hooks = []

  # `ops.device` is the device-placement context manager of
  # `tensorflow.python.framework.ops`; `ops.DEVICE` is not defined there.
  with ops.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx,
        input_device=device,
        invocation_index=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')
      if batch_axis is not None:
        raise TypeError('For mode PREDICT, batch_axis is not supported yet.')
      # Wrap the dataset so that every batch carries stopping signals.
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      hooks.append(inputs.dataset_initializer_hook())

    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """A Fn returning the TPU infeed enqueue ops.

    By providing as a Fn, it can be invoked inside the tf.while_loop such that
    the input pipeline for multiple iterations can be executed by one
    Session.run call.

    Returns:
      The enqueue ops, or a dict with keys 'ops' and 'signals' when the inputs
      provide signals.
    """
    with ops.device(device):
      num_of_replicas_per_host = ctx.num_of_replicas_per_host
      # Convert user input to features and labels.  If the user returns a
      # dataset, it is initialized and the features and labels extracted via
      # `dataset.iterator.get_next()`
      features, labels = inputs.features_and_labels()
      signals = inputs.signals()

      inputs_structure_recorder.validate_and_record_structure(features, labels)
      unsharded_tensor_list = (
          inputs_structure_recorder.flatten_features_and_labels(
              features, labels, signals))

      infeed_queue = tpu_feed.InfeedQueue(
          tuple_types=[t.dtype for t in unsharded_tensor_list],
          tuple_shapes=[t.shape for t in unsharded_tensor_list],
          shard_dimensions=batch_axis)
      captured_infeed_queue.capture(infeed_queue)
      infeed_queue.set_number_of_shards(num_of_replicas_per_host)
      per_host_enqueue_ops = (
          infeed_queue.split_inputs_and_generate_enqueue_ops(
              unsharded_tensor_list,
              placement_function=lambda x: device,
              tpu_ordinal_function=tpu_ordinal_function_impl))
      if signals is None:
        return per_host_enqueue_ops
      else:
        return {
            'ops': per_host_enqueue_ops,
            'signals': signals,
        }

  return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
def generate_per_host_v2_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, device, host_id):
  """Generates infeed enqueue ops for per-host input_fn on a single host.

  In this (PER_HOST_V2) configuration the input pipeline is invoked once per
  replica of the host.

  Args:
    ctx: Internal TPU context; supplies the mode, the batch size for the
      input_fn, the replica count per host, the device assignment and the TPU
      ordinal function.
    input_fn: User input function, called once with a `TPUContext`; it must
      return a `Dataset`.
    inputs_structure_recorder: Validates, records and flattens the structure
      of features, labels and signals.
    device: Device on which the input pipeline and the enqueue ops are placed.
    host_id: Index of the host the ops are generated for.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset)`.

  Raises:
    TypeError: If `input_fn` does not return a `Dataset`.
  """
  captured_infeed_queue = _CapturedObject()
  hooks = []

  # `ops.device` is the device-placement context manager of
  # `tensorflow.python.framework.ops`; `ops.DEVICE` is not defined there.
  with ops.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx,
        input_device=device,
        invocation_index=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if not is_dataset:
      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
                      'input pipeline configuration.')

    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      # Wrap the dataset so that every batch carries stopping signals.
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True,
          num_invocations_per_step=ctx.num_of_replicas_per_host)

    hooks.append(inputs.dataset_initializer_hook())
    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """Generates the per_host enqueue ops."""
    control_deps = []
    per_host_sharded_inputs = []
    num_replicas_per_host = ctx.num_of_replicas_per_host
    cached_signals = None
    with ops.device(device):
      if not inputs.is_dataset:
        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
      for _ in range(num_replicas_per_host):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          features, labels = inputs.features_and_labels()  # Calls get_next()
          signals = inputs.signals()

          # All the replicas share replica 0's stopping signal.
          # This avoids inconsistent state among different model replicas.
          if cached_signals:
            signals['stopping'] = cached_signals['stopping']
          else:
            cached_signals = signals

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels, signals))
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

      if inputs_structure_recorder.flattened_input_dims:
        input_partition_dims = inputs_structure_recorder.flattened_input_dims
        if signals:
          # NOTE(review): `+=` extends the list object returned by the
          # recorder in place; if the recorder hands out its internal list,
          # this also changes the recorder's state -- confirm that is intended.
          input_partition_dims += [None] * len(signals)
        # pylint: disable=protected-access
        infeed_queue = tpu_feed._PartitionedInfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]),
            host_id=host_id,
            input_partition_dims=input_partition_dims,
            device_assignment=ctx.device_assignment)
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs)
      else:
        infeed_queue = tpu_feed.InfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]))
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=tpu_ordinal_function_impl)
      captured_infeed_queue.capture(infeed_queue)

    if signals is None:
      return per_host_enqueue_ops
    else:
      return {
          'ops': per_host_enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
                                      num_hosts):
  """Generates infeed enqueue ops for one input_fn on all the hosts.

  `input_fn` is invoked once, under the device of host 0. The flattened inputs
  from that single invocation are reused for every replica on every host.

  Args:
    ctx: The internal TPU context. Read for `mode`, `batch_size_for_input_fn`,
      `num_of_replicas_per_host`, `device_assignment` and
      `tpu_host_placement_function`.
    input_fn: The user input function; called with a `TPUContext`.
    inputs_structure_recorder: Recorder used to validate the structure of the
      inputs and to flatten features/labels/signals.
    num_hosts: Number of hosts to generate enqueue ops for.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset)`.
    `enqueue_ops_fn` builds the enqueue ops when called and returns either the
    ops or, when signals are present, a dict with keys 'ops' and 'signals'.

  Raises:
    TypeError: If the mode is PREDICT and `input_fn` does not return a
      `Dataset`.
  """
  captured_infeed_queue = _CapturedObject()
  hooks = []
  device_0 = ctx.tpu_host_placement_function(host_id=0)
  with ops.device(device_0):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device_0, invocation_index=0)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')

      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      hooks.append(inputs.dataset_initializer_hook())
    num_replicas_per_host = ctx.num_of_replicas_per_host

  def tpu_ordinal_function_impl(replica_id):
    """Maps a replica id to the TPU ordinal used for its enqueue op."""
    if ctx.device_assignment:
      return ctx.device_assignment.tpu_ordinal(replica=replica_id)
    else:
      return replica_id % num_replicas_per_host

  def device_function_impl(replica_id):
    """Maps a replica id to the host device its enqueue op is placed on."""
    return ctx.tpu_host_placement_function(replica_id=replica_id)

  def enqueue_ops_fn():
    """Generates enqueue ops for all the hosts."""
    broadcasted_inputs = []
    flattened_inputs = None  # Cache result from input_fn.
    signals = None
    for host_id in xrange(num_hosts):
      with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
        for _ in xrange(num_replicas_per_host):
          # Note: input_fn is only called once at host 0 for the first replica.
          # The features and labels returned from that invocation are
          # broadcasted to other replicas (including the replicas on other
          # hosts).
          if flattened_inputs is None:
            features, labels = inputs.features_and_labels()  # Calls get_next()
            signals = inputs.signals()

            inputs_structure_recorder.validate_and_record_structure(
                features, labels)
            flattened_inputs = (
                inputs_structure_recorder.flatten_features_and_labels(
                    features, labels, signals))
          broadcasted_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(broadcasted_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)
    enqueue_ops = infeed_queue.generate_enqueue_ops(
        broadcasted_inputs,
        tpu_ordinal_function=tpu_ordinal_function_impl,
        placement_function=device_function_impl)

    if signals is None:
      return enqueue_ops
    else:
      return {
          'ops': enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
class _InputPipeline(object):
  """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.

  `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from
  call site.  To be precise, based on the configuration in
  `_InternalTPUContext`,  it invokes `input_fn` for all cores (usually
  multi-host TPU training) or for one host (usually for single-host TPU
  evaluation), and sends all `features` and `labels` returned by `input_fn` to
  TPU infeed. For per-core invocation, `features` and `labels` are piped to
  infeed directly, one tuple for each core. For per-host invocation,  `features`
  and `labels` are split at host (with respect to `batch_axis`) and piped to all
  cores accordingly.

  In addition, flatten/unflatten are handled by `_InputPipeline` also.  Model
  inputs returned by the `input_fn` can have one of the following forms:
  1. features
  2. (features, labels)
  3. ((arbitrarily nested structure of features), labels)

  Internally, form 1 is reformed to `(features, None)` as features and labels
  are passed separately to underlying methods. For TPU training, TPUEstimator
  may expect multiple `features` and `labels` tuples one for each core.

  TPUEstimator allows various different structures for inputs (namely `features`
  and `labels`).  `features` can be `Tensor`, dict of string name to `Tensor`,
  or nested tuples and `labels` could be `None`, `Tensor`, or dict of string
  name to `Tensor`. TPU infeed/outfeed library expects flattened tensor list.
  So, `features` and `labels` need to be flattened, before infeed enqueue, and
  the structure of them needs to be recorded, in order to restore them after
  infeed dequeue.
  """

  class InputsStructureRecorder(object):
    """The recorder to record inputs structure."""

    def __init__(self, input_partition_dims=None):
      """Initializes the recorder.

      Args:
        input_partition_dims: Optional sequence with one or two elements: the
          partition dims for features and, optionally, for labels. When
          provided, the first element must not be None.
      """
      # Holds the structure of inputs
      self._feature_structure = {}
      self._flattened_input_dims = None

      if input_partition_dims:
        # This should have been validated in TPUConfig.
        assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.'
        if len(input_partition_dims) == 2:
          self._feature_dims, self._label_dims = input_partition_dims
        else:
          self._feature_dims = input_partition_dims[0]
          self._label_dims = None

        assert self._feature_dims is not None, ('input_partition_dims[0] must '
                                                'not be None')
      else:
        self._feature_dims = None
        self._label_dims = None

      # Internal state.
      self._initialized = False

    @property
    def flattened_input_dims(self):
      """The flattened partition dims; None when no feature dims were given.

      Only valid after `validate_and_record_structure` has been called.
      """
      assert self._initialized, 'InputsStructureRecorder is not initialized.'
      return self._flattened_input_dims

    def has_labels(self):
      """Returns True if labels have been recorded by a flatten call."""
      return 'labels' in self._feature_structure

    def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims,
                            label_dims_names, label_names, has_labels):
      """Flatten input dims with the same order as flattened input tensors.

      Args:
        feature_dims: Partition dims for features; indexed by name when
          `feature_dims_names` is non-empty, otherwise used as a single entry.
        feature_dims_names: Ordered key names of `feature_dims`, or a falsy
          value when `feature_dims` is not keyed.
        label_dims: Partition dims for labels; indexed by name when
          `label_dims_names` is non-empty.
        label_dims_names: Ordered key names of `label_dims`, or a falsy value.
        label_names: Ordered key names of the labels, or a falsy value.
        has_labels: Whether labels are present at all.

      Returns:
        A list of partition dims, features first and labels after.
      """
      flattened_input_dims = []
      if feature_dims_names:
        # We need a fixed ordering for matching the tensors in features.
        flattened_input_dims.extend(
            [feature_dims[name] for name in feature_dims_names])
      else:
        flattened_input_dims.append(feature_dims)

      if label_dims_names:
        # We need a fixed ordering for matching the tensors in labels.
        flattened_input_dims.extend(
            [label_dims[name] for name in label_dims_names])
      else:
        if label_names:
          num_tensors_in_label = len(label_names)
        else:
          num_tensors_in_label = int(has_labels)
        # Setting `None` in input_partition_dims[1] will apply `None` to
        # all the tensors in labels, regardless of internal structure.
        flattened_input_dims.extend([label_dims] * num_tensors_in_label)

      return flattened_input_dims

    def validate_and_record_structure(self, features, labels):
      """Validates and records the structure of `features` and `labels`.

      The partition dims are checked against the key names of the inputs only
      on the first call; later calls leave the recorded state untouched.

      Args:
        features: The features returned by `input_fn`.
        labels: The labels returned by `input_fn`; may be None.

      Raises:
        ValueError: If the key names of the configured partition dims do not
          match the key names of `features` or `labels`.
      """
      # Extract structure.
      has_labels = labels is not None
      feature_names = _extract_key_names(features)
      label_names = _extract_key_names(labels)

      if not self._initialized:
        # Record structure.
        self._initialized = True
        if self._feature_dims is not None:
          feature_dims_names = _extract_key_names(self._feature_dims)
          if feature_dims_names != feature_names:
            raise ValueError(
                'TPUConfig.input_partition_dims[0] mismatched feature'
                ' keys. Expected {}, got {}'.format(feature_names,
                                                    feature_dims_names))

          label_dims_names = _extract_key_names(self._label_dims)
          if self._label_dims is not None and label_dims_names != label_names:
            raise ValueError(
                'TPUConfig.input_partition_dims[1] mismatched label'
                ' keys. Expected {}, got {}'.format(label_names,
                                                    label_dims_names))

          self._flattened_input_dims = self._flatten_input_dims(
              self._feature_dims, feature_dims_names, self._label_dims,
              label_dims_names, label_names, has_labels)

    def flatten_features_and_labels(self, features, labels, signals=None):
      """Flattens the `features` and `labels` to a single tensor list.

      The passed values are also stored as the recorded structure, which
      `unflatten_features_and_labels` later uses to restore the inputs.

      Args:
        features: The features to flatten.
        labels: The labels to flatten; skipped when None.
        signals: Optional signals to flatten; skipped when None.

      Returns:
        The result of `data_nest.flatten` on the recorded structure.
      """
      self._feature_structure['features'] = features
      if labels is not None:
        self._feature_structure['labels'] = labels
      if signals is not None:
        self._feature_structure['signals'] = signals
      return data_nest.flatten(self._feature_structure)

    def unflatten_features_and_labels(self, flattened_inputs):
      """Restores the flattened inputs to original features and labels form.

      Args:
        flattened_inputs: Flattened inputs for each shard.

      Returns:
        An `_Inputs` instance holding the restored `features`, plus `labels`
        and `signals` when they were recorded (None otherwise). Each one, if
        present, should have identical structure (single tensor vs dict) as
        the one returned by input_fn.

      Raises:
        ValueError: If the number of expected tensors from `flattened_inputs`
          mismatches the recorded structure.
      """
      unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure,
                                                      flattened_inputs)
      return _Inputs(
          unflattened_inputs['features'],
          unflattened_inputs.get('labels'),
          signals=unflattened_inputs.get('signals'))

  def __init__(self, input_fn, batch_axis, ctx):
    """Constructor.

    Args:
      input_fn: input fn for train or eval.
      batch_axis: A python tuple of int values describing how each tensor
        produced by the Estimator `input_fn` should be split across the TPU
        compute shards.
      ctx: A `_InternalTPUContext` instance with mode.
    """
    self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder(
        ctx.input_partition_dims)

    self._sharded_per_core = ctx.is_input_sharded_per_core()
    self._input_fn = input_fn
    self._infeed_queue = None
    self._ctx = ctx
    self._batch_axis = batch_axis

  def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Generates infeed enqueue ops and dequeue_fn.

    Returns:
      A tuple `(enqueue_ops, dequeue_fn, all_hooks,
      run_infeed_loop_on_coordinator)`.
    """
    # While tf.while_loop is called, the body function, which invokes
    # `enqueue_fn` passed in, is called to construct the graph. So, input_fn
    # structure is recorded.
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = (
        self._invoke_input_fn_and_record_structure())

    self._validate_input_pipeline()

    def dequeue_fn():
      """dequeue_fn is used by TPU to retrieve the tensors."""
      # In the model-parallel case, both the host-side and device-side
      # computations must agree on the core on which infeed takes place. We
      # choose to perform infeed on logical core 0 of each replica.
      values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      # The unflatten process uses the structure information recorded above.
      return self._inputs_structure_recorder.unflatten_features_and_labels(
          values)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)

  def _invoke_input_fn_and_record_structure(self):
    """Deploys the input pipeline and record input structure.

    Returns:
      A tuple `(enqueue_ops, all_hooks, run_infeed_loop_on_coordinator)`. The
      flag stays True unless the enqueue ops were wrapped in a while loop.
    """
    enqueue_ops = []
    infeed_queues = []
    all_hooks = []
    num_hosts = self._ctx.num_hosts
    tpu_host_placement_fn = self._ctx.tpu_host_placement_function

    run_infeed_loop_on_coordinator = True

    if self._sharded_per_core:
      # Per-Core input pipeline deployment.
      # Invoke input pipeline for each core and placed on the corresponding
      # host.
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
        with ops.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            enqueue_ops_fn, captured_infeed_queue = (
                generate_per_core_enqueue_ops_fn_for_host(
                    self._ctx, self._input_fn, self._inputs_structure_recorder,
                    host_device, host_id))

            if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
              run_infeed_loop_on_coordinator = False
              enqueue_ops.append(
                  _wrap_computation_in_while_loop(
                      device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            # Infeed_queue_getter must be called after enqueue_ops_fn is called.
            infeed_queues.append(captured_infeed_queue.get())

    elif self._ctx.is_input_broadcast_with_iterators():
      # Only calls input_fn in host 0.
      host_device = tpu_host_placement_fn(host_id=0)
      enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
          generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
                                            self._inputs_structure_recorder,
                                            num_hosts))
      all_hooks.extend(hooks)
      if is_dataset:
        run_infeed_loop_on_coordinator = False
        wrap_fn = (
            _wrap_computation_in_while_loop
            if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
            _wrap_computation_in_while_loop_with_stopping_signals)
        enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
      else:
        enqueue_ops.append(enqueue_ops_fn())
      infeed_queues.append(captured_infeed_queue.get())
    else:
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
        with ops.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            if self._ctx.is_input_per_host_with_iterators():
              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
                  generate_per_host_v2_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, host_device, host_id))
            else:
              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
                  generate_per_host_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, self._batch_axis,
                      host_device, host_id))
            all_hooks.extend(hooks)

            # NOTE(xiejw): We dispatch here based on the return type of the
            # users `input_fn`.
            #
            # 1. If input_fn returns a Dataset instance, we initialize the
            # iterator outside of tf.while_loop, and call the iterator.get_next
            # inside tf.while_loop.  This should be always safe.
            #
            # 2. If input_fn returns (features, labels), it is too late to wrap
            # them inside tf.while_loop, as resource initialization cannot be
            # handled in TF control flow properly. In this case, we will use
            # python loop to enqueue the data into TPU system.  This may be
            # slow compared to the previous case.
            if is_dataset:
              run_infeed_loop_on_coordinator = False
              wrap_fn = (
                  _wrap_computation_in_while_loop
                  if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
                  _wrap_computation_in_while_loop_with_stopping_signals)
              enqueue_ops.append(
                  wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            infeed_queues.append(captured_infeed_queue.get())
    # infeed_queue is used to generate dequeue ops. The only thing it uses for
    # dequeue is dtypes and shapes. So, any one can be used. Here, grab the
    # first one.
    self._infeed_queue = infeed_queues[0]
    return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator

  def _validate_input_pipeline(self):
    """Validates the input pipeline.

    Perform some sanity checks to log user friendly information. We should
    error out to give users better error message. But, if
    _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
    user code, so, log a warning.

    Raises:
      RuntimeError: If the validation failed.
    """
    if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
      err_msg = ('Input pipeline contains one or more QueueRunners. '
                 'It could be slow and not scalable. Please consider '
                 'converting your input pipeline to use `tf.data` instead (see '
                 'https://www.tensorflow.org/guide/datasets for '
                 'instructions.')
      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
        raise RuntimeError(err_msg)
      else:
        logging.warning(err_msg)
class _ModelFnWrapper(object):
  """A `model_fn` wrapper.

  This makes calling model_fn on CPU and TPU easier and more consistent and
  performs necessary check and mutation required by TPU training and evaluation.

  In addition, this wrapper manages converting the `model_fn` to a single TPU
  train and eval step.
  """

  def __init__(self, model_fn, train_cache_fn, eval_cache_fn, config, params, ctx):
    """Stores the user `model_fn` and the objects needed to call it.

    Args:
      model_fn: The user provided model function.
      train_cache_fn: Stored as `_train_cache_fn`; not used by the methods
        of this class.
      eval_cache_fn: Stored as `_eval_cache_fn`; not used by the methods of
        this class.
      config: The run configuration; deep-copied before each model_fn call.
      params: The user params; deep-copied before each model_fn call.
      ctx: The internal TPU context.
    """
    self._model_fn = model_fn
    self._train_cache_fn = train_cache_fn
    self._eval_cache_fn = eval_cache_fn
    self._config = config
    self._params = params
    self._ctx = ctx

  def call_without_tpu(self, features, labels, is_export_mode):
    """Calls the model_fn directly, without any TPU step conversion."""
    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)

  def convert_to_single_tpu_train_step(self, dequeue_fn):
    """Converts user provided `model_fn` as a single train step on TPU.

    The user provided `model_fn` takes input tuple
    (features, labels) and produces the EstimatorSpec with train_op and loss for
    train `mode`. This usually represents a single train computation on CPU.

    For TPU training, a train (computation) step is first wrapped in a
    tf.while_loop control flow to repeat for many times and then replicated to
    all TPU shards. Besides the input should be taken from TPU infeed rather
    than input pipeline (input_fn) directly. To fit TPU loop and replicate
    pattern, the original train computation should be reformed, which is the
    returned `train_step`.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from TPU
        infeed dequeue channel.

    Returns:
      A tuple of train_fn, host_calls, captured scaffold_fn, and captured
      training hooks. The train_fn representing the train step for TPU.
    """

    host_call = _OutfeedHostCall(self._ctx)
    captured_scaffold_fn = _CapturedObject()
    captured_training_hooks = _CapturedObject()

    def train_step(loss, *cache):
      """Training step function for use inside a while loop."""
      # When `track_mean` is set in params, the incoming loss is accumulated
      # with the new loss instead of being replaced by it.
      track_mean = self._params.get('track_mean', False)
      if not track_mean:
        del loss  # unused; required in function signature.
      inputs = dequeue_fn()
      features, labels = inputs.features_and_labels()

      # Consume the current cache
      estimator_spec = self._verify_estimator_spec(
          self._call_model_fn(features, labels, cache=cache))

      # Retrieve the new returned cache.
      # `cache` consists of a list of tensors, potentially empty (of length 0)
      cache = estimator_spec.cache
      new_loss, train_op = estimator_spec.loss, estimator_spec.train_op

      if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
      else:
        captured_scaffold_fn.capture(None)

      captured_training_hooks.capture(estimator_spec.training_hooks)

      # We must run train_op to update the variables prior to running the
      # outfeed.
      with ops.control_dependencies([train_op]):
        host_call_outfeed_ops = []
        if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)  # pylint: disable=protected-access
            and estimator_spec.host_call is not None):
          host_call.record({'host_call': estimator_spec.host_call})
          host_call_outfeed_ops = host_call.create_enqueue_op()
        with ops.control_dependencies(host_call_outfeed_ops):
          if track_mean:
            loss = tensorflow.stop_gradient(loss)
            return [math_ops.add(loss, new_loss)] + cache
          else:
            return [array_ops.identity(new_loss)] + cache

    return (train_step, host_call, captured_scaffold_fn,
            captured_training_hooks)

  def convert_to_single_tpu_eval_step(self, dequeue_fn):
    """Converts user provided `model_fn` as a single eval step on TPU.

    Similar to training, the user provided `model_fn` takes input tuple
    (features, labels) and produces the TPUEstimatorSpec with eval_metrics for
    eval `mode`. This usually represents a single evaluation computation on CPU.

    For TPU evaluation, a eval (computation) step is first wrapped in a
    tf.while_loop control flow to repeat for many times and then replicated to
    all TPU shards. Besides the input and output are slightly different. Input,
    features and labels, should be taken from TPU infeed rather than input
    pipeline (input_fn) directly. Output is managed in two stages.  First, the
    model outputs as the result of evaluation computation, usually model logits,
    should be transferred from TPU system to CPU. Then, all model outputs are
    concatenated first on CPU and sent to the metric_fn for metrics computation.
    To fit TPU evaluation pattern, the original eval computation should be
    reformed, which is the returned `eval_step`.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from TPU
        infeed dequeue channel.

    Returns:
      A tuple of eval_fn, host_calls, captured scaffold_fn, and captured eval
      hooks. The eval_fn representing the eval step for TPU.
    """
    host_calls = _OutfeedHostCall(self._ctx)
    captured_scaffold_fn = _CapturedObject()
    captured_eval_hooks = _CapturedObject()

    def eval_step(total_loss, *cache):
      """Evaluation step function for use inside a while loop."""
      inputs = dequeue_fn()
      features, labels = inputs.features_and_labels()

      # Consume the current cache
      tpu_estimator_spec = self._call_model_fn(features, labels, cache=cache)
      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
        raise RuntimeError(
            'estimator_spec used by TPU evaluation must have type'
            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))

      # Retrieve the new returned cache
      cache = tpu_estimator_spec.cache
      loss = tpu_estimator_spec.loss

      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
      captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks)

      to_record = {}
      if tpu_estimator_spec.eval_metrics:
        to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
      if tpu_estimator_spec.host_call is not None:
        # We assume that evaluate won't update global step, so we don't wrap
        # this host_call.
        to_record['host_call'] = tpu_estimator_spec.host_call
      host_calls.record(to_record)

      with ops.control_dependencies(host_calls.create_enqueue_op()):
        return [math_ops.add(total_loss, loss)] + cache

    return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks

  def convert_to_single_tpu_predict_step(self, dequeue_fn):
    """Converts user provided `model_fn` as a single predict step on TPU.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from TPU
        infeed dequeue channel.

    Returns:
      A tuple of predict_fn, host_calls, captured scaffold_fn, and captured
      predict hooks. The predict_fn representing the predict step for TPU.
    """
    host_calls = _OutfeedHostCall(self._ctx)
    captured_scaffold_fn = _CapturedObject()
    captured_predict_hooks = _CapturedObject()

    def predict_step(unused_scalar_stopping_signal):
      """Prediction step function for use inside a while loop."""
      inputs = dequeue_fn()
      features, labels = inputs.features_and_labels()
      stopping_signals = inputs.signals()

      assert stopping_signals is not None, (
          'Internal Error: `signals` is missing.')

      tpu_estimator_spec = self._call_model_fn(
          features, labels, is_export_mode=False)
      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
        raise RuntimeError(
            'estimator_spec used by TPU prediction must have type'
            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))

      self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)

      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
      captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks)
      to_record = {}
      # Predictions and signals are passed through the outfeed unchanged.
      identity_fn = lambda **kwargs: kwargs
      to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
      to_record['signals'] = [identity_fn, stopping_signals]
      if tpu_estimator_spec.host_call is not None:
        to_record['host_call'] = tpu_estimator_spec.host_call
      host_calls.record(to_record)

      with ops.control_dependencies(host_calls.create_enqueue_op()):
        return _StopSignals.as_scalar_stopping_signal(stopping_signals)

    return (predict_step, host_calls, captured_scaffold_fn,
            captured_predict_hooks)

  def _verify_tpu_spec_predictions(self, predictions):
    """Validates TPUEstimatorSpec.predictions dict.

    Args:
      predictions: The `predictions` of a `TPUEstimatorSpec`.

    Returns:
      `predictions`, unchanged.

    Raises:
      TypeError: If `predictions` is not a dict.
      ValueError: If the first dimension of any tensor is not static.
    """
    # TODO(xiejw): Adds validation for prediction dictionary.
    # TODO(xiejw): Adds support for single tensor as predictions.
    if not isinstance(predictions, dict):
      raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')

    for (key, tensor) in predictions.items():
      if tensor.shape[0].value is None:
        raise ValueError(
            'The tensor with key ({}) in TPUEstimatorSpec.predictions has '
            'dynamic shape (should be static). Tensor: {}'.format(
                key, tensor))
    return predictions

  def _validate_model_features_and_labels(self,
                                          features,
                                          labels,
                                          is_export_mode):
    """Validates that the features and labels for the model function are valid.

    A valid features/labels object is the one with:
    - Type: Tensor or a dictionary of Tensors
    - Static shape if is_export_mode is False.

    Args:
      features: the features that would be input to the model function.
      labels: the labels that would be input to the model function.
      is_export_mode: boolean value specifying if in export mode.

    Raises:
      TypeError: If features/labels are not of the correct type.
      ValueError: If features/labels have dynamic shape.
    """

    def validate(obj, obj_name):
      """Helper validate function."""
      if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict):
        raise TypeError(
            'The {} to the model returned by input_fn must be either a Tensor '
            'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name,
                                                        obj))
      # The static-shape requirement is skipped in export mode and on CPU.
      if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):
        return
      if isinstance(obj, ops.Tensor):
        if not obj.get_shape().is_fully_defined():
          raise ValueError(
              'The {} to the model returned by input_fn must have static shape.'
              ' Tensor: {}'.format(obj_name, obj))
      else:
        for (key, value) in obj.items():
          flattened_tensors = data_nest.flatten(value)
          for tensor in flattened_tensors:
            if not tensor.get_shape().is_fully_defined():
              raise ValueError(
                  'The {} to the model returned by input_fn must have static '
                  'shape. Key: \'{}\', Tensor: {}'.format(
                      obj_name, key, tensor))

    validate(features, 'features')
    if labels is not None:
      validate(labels, 'labels')

  def _call_model_fn(self, features, labels, cache=None, is_export_mode=False):
    """Calls the model_fn with required parameters.

    Only the arguments that `model_fn` declares are passed to it. The params
    handed to `model_fn` are a deep copy, extended with the cache (when given),
    the batch size (when known), the use-TPU flag and, on TPU, a `TPUContext`.

    Args:
      features: The features passed to `model_fn`.
      labels: The labels passed to `model_fn`; may be None.
      cache: Optional cache, stored under `params['cache']` when not None.
      is_export_mode: Whether the model_fn is being called for export.

    Returns:
      The spec returned by `model_fn`. When running on CPU a
      `TPUEstimatorSpec` is converted with `as_estimator_spec()`.

    Raises:
      ValueError: If `model_fn` does not accept `labels` although labels are
        given, or if it does not accept `params`.
    """
    self._validate_model_features_and_labels(features, labels, is_export_mode)
    model_fn_args = function_utils.fn_args(self._model_fn)
    kwargs = {}

    # Makes deep copy with `config` and `params` in case user mutates them.
    config = copy.deepcopy(self._config)
    params = copy.deepcopy(self._params)

    if 'labels' in model_fn_args:
      kwargs['labels'] = labels
    elif labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
      kwargs['mode'] = self._ctx.mode
    if 'config' in model_fn_args:
      kwargs['config'] = config
    if 'params' in model_fn_args:
      kwargs['params'] = params

    if cache is not None:
      params['cache'] = cache

    if 'params' not in model_fn_args:
      raise ValueError('model_fn ({}) does not include params argument, '
                       'required by TPUEstimator to pass batch size as '
                       'params[\'batch_size\']'.format(self._model_fn))

    if is_export_mode:
      batch_size_for_model_fn = None
    else:
      batch_size_for_model_fn = self._ctx.batch_size_for_model_fn

    if batch_size_for_model_fn is not None:
      _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)

    running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode)
    _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu)

    if not running_on_cpu:
      user_context = tpu_context.TPUContext(
          internal_ctx=self._ctx, call_from_input_fn=False)
      _add_item_to_params(params, _CTX_KEY, user_context)

    estimator_spec = self._model_fn(features=features, **kwargs)
    if (running_on_cpu and
        isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)):  # pylint: disable=protected-access
      # The estimator_spec will be passed to `Estimator` directly, which expects
      # type `EstimatorSpec`.
      return estimator_spec.as_estimator_spec()
    else:
      return estimator_spec

  def _verify_estimator_spec(self, estimator_spec):
    """Validates the estimator_spec.

    Args:
      estimator_spec: The spec returned by the model_fn.

    Returns:
      `estimator_spec`, unchanged.

    Raises:
      ValueError: If a plain `EstimatorSpec` sets `training_chief_hooks`.
    """
    if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
      return estimator_spec

    err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
    if estimator_spec.training_chief_hooks:
      raise ValueError(
          err_msg.format('training_chief_hooks') +
          ' If you want to pass training hooks, please pass via'
          ' training_hooks.')

    if estimator_spec.scaffold:
      logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. '
                      'Please use TPUEstimatorSpec.')
    return estimator_spec
class _OutfeedHostCall(object):
"""Support for `eval_metrics` and `host_call` in TPUEstimatorSpec."""
def __init__(self, ctx):
self._ctx = ctx
self._names = []
# All of these are dictionaries of lists keyed on the name.
self._host_fns = {}
self._tensor_keys = collections.defaultdict(list)
self._tensors = collections.defaultdict(list)
self._tensor_dtypes = collections.defaultdict(list)
self._tensor_shapes = collections.defaultdict(list)
@staticmethod
def validate(host_calls):
"""Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`."""
for name, host_call in host_calls.items():
if not isinstance(host_call, (tuple, list)):
raise ValueError('{} should be tuple or list'.format(name))
if len(host_call) != 2:
raise ValueError('{} should have two elements.'.format(name))
if not callable(host_call[0]):
raise TypeError('{}[0] should be callable.'.format(name))
if not isinstance(host_call[1], (tuple, list, dict)):
raise ValueError('{}[1] should be tuple or list, or dict.'.format(name))
if isinstance(host_call[1], (tuple, list)):
fullargspec = tf_inspect.getfullargspec(host_call[0])
fn_args = function_utils.fn_args(host_call[0])
# wrapped_hostcall_with_global_step uses varargs, so we allow that.
if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
raise RuntimeError(
'In TPUEstimatorSpec.{}, length of tensors {} does not match '
'method args of the function, which takes {}.'.format(
name, len(host_call[1]), len(fn_args)))
@staticmethod
def create_cpu_hostcall(host_calls):
"""Runs on the host_call on CPU instead of TPU when use_tpu=False."""
_OutfeedHostCall.validate(host_calls)
ret = {}
for name, host_call in host_calls.items():
host_fn, tensors = host_call
if isinstance(tensors, (tuple, list)):
ret[name] = host_fn(*tensors)
else:
# Must be dict.
try:
ret[name] = host_fn(**tensors)
except TypeError as e:
logging.warning(
'Exception while calling %s: %s. It is likely the tensors '
'(%s[1]) do not match the '
'function\'s arguments', name, e, name)
raise e
return ret
def record(self, host_calls):
"""Records the host_call structure."""
for name, host_call in host_calls.items():
host_fn, tensor_list_or_dict = host_call
self._names.append(name)
self._host_fns[name] = host_fn
if isinstance(tensor_list_or_dict, dict):
for (key, tensor) in six.iteritems(tensor_list_or_dict):
self._tensor_keys[name].append(key)
self._tensors[name].append(tensor)
self._tensor_dtypes[name].append(tensor.dtype)
self._tensor_shapes[name].append(tensor.shape)
else:
# List or tuple.
self._tensor_keys[name] = None
for tensor in tensor_list_or_dict:
self._tensors[name].append(tensor)
self._tensor_dtypes[name].append(tensor.dtype)
self._tensor_shapes[name].append(tensor.shape)
def create_enqueue_op(self):
"""Create the op to enqueue the recorded host_calls.
Returns:
A list of enqueue ops, which is empty if there are no host calls.
"""
if not self._names:
return []
tensors = []
# TODO(jhseu): Consider deduping tensors.
for name in self._names:
tensors.extend(self._tensors[name])
with ops.DEVICE(tpu.core(0)):
return [tpu_ops.outfeed_enqueue_tuple(tensors)]
def create_tpu_hostcall(self):
  """Sends the tensors through outfeed and runs the host_fn on CPU.

  The tensors are concatenated along dimension 0 to form a global tensor
  across all shards. The concatenated tensor is passed to the host_fn and
  executed on the first host.

  Returns:
    A dictionary mapping name to the return type of the host_call by that
    name.

  Raises:
    RuntimeError: If outfeed tensor is scalar.
    TypeError: Re-raised, after a warning is logged, when a host_fn called
      with keyword arguments raises `TypeError`.
  """
  if not self._names:
    return {}

  ret = {}
  # For each i, dequeue_ops[i] is a list containing the tensors from all
  # shards. This list is concatenated later.
  dequeue_ops = []
  tensor_dtypes = []
  tensor_shapes = []
  for name in self._names:
    dequeue_ops.extend([] for _ in self._tensors[name])
    tensor_dtypes.extend(self._tensor_dtypes[name])
    tensor_shapes.extend(self._tensor_shapes[name])

  # Outfeed ops execute on each replica's first logical core. Note: we must
  # constraint it such that we have at most one outfeed dequeue and enqueue
  # per replica.
  for i in xrange(self._ctx.num_replicas):
    host_device, ordinal_id = self._ctx.device_for_replica(i)
    # `ops.device` is the device-scope context manager; the framework `ops`
    # module has no `DEVICE` attribute.
    with ops.device(host_device):
      outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
          dtypes=tensor_dtypes,
          shapes=tensor_shapes,
          device_ordinal=ordinal_id)
      for j, item in enumerate(outfeed_tensors):
        dequeue_ops[j].append(item)

  # Deconstruct dequeue ops: slice the flat list back into one group per
  # host call, in the order the calls were recorded.
  dequeue_ops_by_name = {}
  pos = 0
  for name in self._names:
    num_tensors = len(self._tensors[name])
    dequeue_ops_by_name[name] = dequeue_ops[pos:pos + num_tensors]
    pos += num_tensors

  # It is assumed evaluation always happens on single host TPU system. So,
  # place all ops on tpu host if possible.
  #
  # TODO(jhseu): Evaluate whether this is right for summaries.
  with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
    for name in self._names:
      # A distinct name keeps the flat `dequeue_ops` list above from being
      # shadowed inside the loop.
      call_ops = dequeue_ops_by_name[name]
      for i, shard_tensors in enumerate(call_ops):
        if shard_tensors[0].shape.ndims == 0:
          raise RuntimeError(
              'All tensors outfed from TPU should preserve batch size '
              'dimension, but got scalar {}'.format(shard_tensors[0]))
        # TODO(xiejw): Allow users to specify the axis for batch size
        # dimension.
        call_ops[i] = array_ops.concat(shard_tensors, axis=0)

      if self._tensor_keys[name] is not None:
        # The user-provided eval_metrics[1] is a dict.
        call_kwargs = dict(zip(self._tensor_keys[name], call_ops))
        try:
          ret[name] = self._host_fns[name](**call_kwargs)
        except TypeError as e:
          logging.warning(
              'Exception while calling %s: %s. It is likely the tensors '
              '(%s[1]) do not match the '
              'function\'s arguments', name, e, name)
          raise e
      else:
        ret[name] = self._host_fns[name](*call_ops)

  return ret
class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
  """Session hook that runs host-call tensors when `use_tpu=False`.

  The given tensors are fetched on every `session.run`. Summary writers are
  initialized once the session exists and flushed when it ends.
  """

  def __init__(self, tensors):
    """Stores the tensors to fetch on every run.

    Args:
      tensors: The fetches handed to `SessionRunArgs` in `before_run`.
    """
    self._tensors = tensors

  def begin(self):
    """Collects the summary writer init ops and builds matching flush ops."""
    # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
    # create a separate hook to guarantee execution order, because summaries
    # need to be initialized before the outfeed thread starts.
    # TODO(jhseu): Make a wrapper hook instead?
    self._init_ops = contrib_summary.summary_writer_initializer_op()
    # The first input of each initializer op is used as the writer to flush.
    self._finalize_ops = [
        contrib_summary.flush(writer=init_op.inputs[0])
        for init_op in self._init_ops
    ]

  def after_create_session(self, session, coord):
    """Runs the summary writer initializers in the new session."""
    del coord  # Unused.
    session.run(self._init_ops)

  def before_run(self, run_context):
    """Requests the stored tensors as extra fetches for this run."""
    del run_context  # Unused.
    return basic_session_run_hooks.SessionRunArgs(self._tensors)

  def end(self, session):
    """Flushes every summary writer found in `begin`."""
    session.run(self._finalize_ops)
class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
  """Step counter hook that also reports examples/sec.

  Both `global_step/sec` and `examples/sec` are logged, and additionally
  written as summaries when a summary writer is available.
  """

  def __init__(self,
               batch_size,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):
    """Stores the batch size and forwards the rest to `StepCounterHook`.

    Args:
      batch_size: Multiplied by the measured steps/sec to obtain
        examples/sec.
      every_n_steps: Forwarded to the parent constructor.
      every_n_secs: Forwarded to the parent constructor.
      output_dir: Forwarded to the parent constructor.
      summary_writer: Forwarded to the parent constructor.
    """
    self._batch_size = batch_size
    super(ExamplesPerSecondHook, self).__init__(
        every_n_steps=every_n_steps,
        every_n_secs=every_n_secs,
        output_dir=output_dir,
        summary_writer=summary_writer)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    """Computes both rates, then writes summaries and log lines.

    Args:
      elapsed_steps: Number of steps taken during the measured interval.
      elapsed_time: Length of the measured interval; the divisor of the rate.
      global_step: Step value attached to the emitted summaries.
    """
    steps_per_sec = elapsed_steps / elapsed_time
    examples_per_sec = self._batch_size * steps_per_sec

    writer = self._summary_writer
    if writer is not None:
      tagged_rates = (
          ('global_step/sec', steps_per_sec),
          ('examples/sec', examples_per_sec),
      )
      for tag, rate in tagged_rates:
        writer.add_summary(
            Summary(value=[Summary.Value(tag=tag, simple_value=rate)]),
            global_step)

    logging.info('global_step/sec: %g', steps_per_sec)
    logging.info('examples/sec: %g', examples_per_sec)
class InstallSignalHandlerHook(session_run_hook.SessionRunHook):
  """Change SIGINT (CTRL^C) handler to force quit the process.

  The default behavior often results in hanging processes.
  The original handler is restored after training/evaluation.
  """

  def __init__(self):
    """Remembers the SIGINT handler installed at construction time."""
    self._signal_fn = signal.getsignal(signal.SIGINT)

  def before_run(self, run_context):
    """Installs the default SIGINT action before each run call."""
    del run_context  # Unused.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

  def end(self, session):
    """Puts back the handler that was saved in `__init__`."""
    del session  # Unused.
    signal.signal(signal.SIGINT, self._signal_fn)
class TPUEstimator(estimator_lib.Estimator):
"""Estimator with TPU support.
TPUEstimator also supports training on CPU and GPU. You don't need to define
a separate `tf.estimator.Estimator`.
TPUEstimator handles many of the details of running on TPU devices, such as
replicating inputs and models for each core, and returning to host
periodically to run hooks.
TPUEstimator transforms a global batch size in params to a per-shard batch
size when calling the `input_fn` and `model_fn`. Users should specify
global batch size in constructor, and then get the batch size for each shard
in `input_fn` and `model_fn` by `params['batch_size']`.
- For training, `model_fn` gets per-core batch size; `input_fn` may get
per-core or per-host batch size depending on `per_host_input_for_training`
in `TPUConfig` (See docstring for TPUConfig for details).
- For evaluation and prediction, `model_fn` gets per-core batch size and
`input_fn` gets per-host batch size.
Evaluation
==========
`model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`
for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return
`EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case
the following discussion on TPU evaluation does not apply.
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
`tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See
`TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns
a dict from metric string name to the result of calling a metric function,
namely a `(metric_tensor, update_op)` tuple.
One can set `use_tpu` to `False` for testing. All training, evaluation, and
predict will be executed on CPU. `input_fn` and `model_fn` will receive
`train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`.
Current limitations:
--------------------
1. TPU evaluation only works on a single host (one TPU worker) except
BROADCAST mode.
2. `input_fn` for evaluation should **NOT** raise an end-of-input exception
(`OutOfRangeError` or `StopIteration`). And all evaluation steps and all
batches should have the same size.
Example (MNIST):
----------------
```
# The metric Fn which runs on CPU.
def metric_fn(labels, logits):
predictions = tf.argmax(logits, 1)
return {
'accuracy': tf.metrics.precision(
labels=labels, predictions=predictions),
}
# Your model Fn which runs on TPU (eval_metrics is list in this example)
def model_fn(features, labels, mode, config, params):
...
logits = ...
if mode == tf.estimator.ModeKeys.EVAL:
return tpu_estimator.TPUEstimatorSpec(
mode=mode,
loss=loss,
eval_metrics=(metric_fn, [labels, logits]))
# or specify the eval_metrics tensors as dict.
def model_fn(features, labels, mode, config, params):
...
final_layer_output = ...
if mode == tf.estimator.ModeKeys.EVAL:
return tpu_estimator.TPUEstimatorSpec(
mode=mode,
loss=loss,
eval_metrics=(metric_fn, {
'labels': labels,
'logits': final_layer_output,
}))
```
Prediction
==========
Prediction on TPU is an experimental feature to support large batch inference.
It is not designed for latency-critical system. In addition, due to some
usability issues, for prediction with small dataset, CPU `.predict`, i.e.,
creating a new `TPUEstimator` instance with `use_tpu=False`, might be more
convenient.
Note: In contrast to TPU training/evaluation, the `input_fn` for prediction
*should* raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be
precise, the ops created by `input_fn` produce one batch of the data.
The `predict()` API processes one batch at a time. When reaching the end of
the data source, an end-of-input exception should be raised by one of these
operations. The user usually does not need to do this manually. As long as the
dataset is not repeated forever, the `tf.data` API will raise an end-of-input
exception automatically after the last batch has been produced.
Note: Estimator.predict returns a Python generator. Please consume all the
data from the generator so that TPUEstimator can shutdown the TPU system
properly for user.
Current limitations:
--------------------
1. TPU prediction only works on a single host (one TPU worker).
2. `input_fn` must return a `Dataset` instance rather than `features`. In
fact, .train() and .evaluate() also support Dataset as return value.
Example (MNIST):
----------------
```
height = 32
width = 32
total_examples = 100
def predict_input_fn(params):
batch_size = params['batch_size']
images = tf.random_uniform(
[total_examples, height, width, 3], minval=-1, maxval=1)
dataset = tf.data.Dataset.from_tensor_slices(images)
dataset = dataset.map(lambda images: {'image': images})
dataset = dataset.batch(batch_size)
return dataset
def model_fn(features, labels, params, mode):
# Generate predictions, called 'output', from features['image']
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.contrib.tpu.TPUEstimatorSpec(
mode=mode,
predictions={
'predictions': output,
'is_padding': features['is_padding']
})
tpu_est = TPUEstimator(
model_fn=model_fn,
...,
predict_batch_size=16)
# Fully consume the generator so that TPUEstimator can shutdown the TPU
# system.
for item in tpu_est.predict(input_fn=predict_input_fn):
# Filter out item if the `is_padding` is 1.
# Process the 'predictions'
```
Exporting
=========
`export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`,
and another with `tag_constants.SERVING` and `tag_constants.TPU`.
At serving time, these tags are used to select metagraph to load.
Before running the graph on TPU, TPU system needs to be initialized. If
TensorFlow Serving model-server is used, this is done automatically. If
not, please call `session.run(tpu.initialize_system())`.
`tpu.outside_compilation` can be used to wrap TPU incompatible ops in
`model_fn`.
Example:
----------------
```
def model_fn(features, labels, mode, config, params):
...
logits = ...
export_outputs = {
'logits': export_output_lib.PredictOutput(
{'logits': logits})
}
def host_call(logits):
class_ids = math_ops.argmax(logits)
classes = string_ops.as_string(class_ids)
export_outputs['classes'] = (
export_output_lib.ClassificationOutput(classes=classes))
tpu.outside_compilation(host_call, logits)
...
```
"""
def __init__(self,
             model_fn=None,
             train_cache_fn=None,
             eval_cache_fn=None,
             model_dir=None,
             config=None,
             params=None,
             use_tpu=True,
             train_batch_size=None,
             eval_batch_size=None,
             predict_batch_size=None,
             batch_axis=None,
             eval_on_tpu=True,
             export_to_tpu=True,
             warm_start_from=None):
  """Constructs a `TPUEstimator` instance.

  Args:
    model_fn: Model function as required by `Estimator` which returns
      EstimatorSpec or TPUEstimatorSpec. `training_hooks`,
      `evaluation_hooks`, and `prediction_hooks` must not capture any TPU
      Tensor inside the model_fn.
    train_cache_fn: Passed through `_augment_model_fn` to the model function
      wrapper. Its contract is defined by that wrapper, not here.
    eval_cache_fn: Passed through `_augment_model_fn` to the model function
      wrapper. Its contract is defined by that wrapper, not here.
    model_dir: Directory to save model parameters, graph and etc. This can
      also be used to load checkpoints from the directory into an estimator
      to continue training a previously saved model. If `None`, the
      model_dir in `config` will be used if set. If both are set, they must
      be same. If both are `None`, a temporary directory will be used.
    config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.
    params: An optional `dict` of hyper parameters that will be passed into
      `input_fn` and `model_fn`. Keys are names of parameters, values are
      basic python types. There are reserved keys for `TPUEstimator`,
      including 'batch_size'.
    use_tpu: A bool indicating whether TPU support is enabled. Currently,
      - TPU training and evaluation respect this bit, but eval_on_tpu can
        override execution of eval. See below.
      - Predict still happens on CPU.
    train_batch_size: An int representing the global training batch size.
      TPUEstimator transforms this global batch size to a per-shard batch
      size, as params['batch_size'], when calling `input_fn` and `model_fn`.
      Cannot be `None` if `use_tpu` is `True`. Must be divisible by total
      number of replicas.
    eval_batch_size: An int representing evaluation batch size. Must be
      divisible by total number of replicas.
    predict_batch_size: An int representing the prediction batch size. Must
      be divisible by total number of replicas.
    batch_axis: A python tuple of int values describing how each tensor
      produced by the Estimator `input_fn` should be split across the TPU
      compute shards. For example, if your input_fn produced (images, labels)
      where the images tensor is in `HWCN` format, your shard dimensions
      would be [3, 0], where 3 corresponds to the `N` dimension of your
      images Tensor, and 0 corresponds to the dimension along which to split
      the labels to match up with the corresponding images. If None is
      supplied, and per_host_input_for_training is True, batches will be
      sharded based on the major dimension. If
      tpu_config.per_host_input_for_training is False or `PER_HOST_V2`,
      batch_axis is ignored.
    eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
      model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
    export_to_tpu: If True, `export_savedmodel()` exports a metagraph for
      serving on TPU besides the one on CPU.
    warm_start_from: Optional string filepath to a checkpoint or SavedModel
      to warm-start from, or a `tf.estimator.WarmStartSettings` object to
      fully configure warm-starting. If the string filepath is provided
      instead of a `WarmStartSettings`, then all variables are warm-started,
      and it is assumed that vocabularies and Tensor names are unchanged.

  Raises:
    ValueError: If `config` is not a `tpu_config.RunConfig`, if `params`
      already contains a reserved key, if `use_tpu` is set without a
      `train_batch_size`, or if `PER_SHARD_V1` input is combined with
      `num_cores_per_replica`.
  """
  # `None` is not a `RunConfig` instance, so one check covers both cases.
  if not isinstance(config, tpu_config.RunConfig):
    raise ValueError(
        '`config` must be provided with type `tpu_config.RunConfig`')

  if params is not None and any(
      key in params for key in _RESERVED_PARAMS_KEYS):
    raise ValueError('{} are reserved keys but existed in params {}.'.format(
        _RESERVED_PARAMS_KEYS, params))

  if use_tpu:
    # Perform some very basic validations. More validations will be found in
    # _InternalTPUContext.
    if train_batch_size is None:
      raise ValueError('`train_batch_size` cannot be `None`')
    util_lib.check_positive_integer(train_batch_size, 'train_batch_size')

    tpu_cfg = config.tpu_config
    uses_per_shard_v1 = (
        tpu_cfg.per_host_input_for_training is
        tpu_config.InputPipelineConfig.PER_SHARD_V1)
    if uses_per_shard_v1 and tpu_cfg.num_cores_per_replica:
      raise ValueError(
          'Model parallelism only supports per host input for training. '
          'Please adjust TPURunconfig.per_host_input_for_training.')

    optional_batch_sizes = (
        (eval_batch_size, 'eval_batch_size'),
        (predict_batch_size, 'predict_batch_size'),
    )
    for batch_size, label in optional_batch_sizes:
      if batch_size is not None:
        util_lib.check_positive_integer(batch_size, label)

  # Verifies the model_fn signature according to Estimator framework.
  estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access

  # We cannot store config and params in this constructor as parent
  # constructor might change them, such as assigning a temp dir for
  # config.model_dir.
  wrapped_model_fn = self._augment_model_fn(
      model_fn, train_cache_fn, eval_cache_fn, batch_axis)

  # Overwrite log_step_count_steps to disable TensorLoggingHook and
  # StepCounterHook from being created in Estimator. TPUEstimator already
  # added equivalent hooks in _augment_model_fn above.
  self._log_every_n_steps = config.log_step_count_steps
  config = config.replace(log_step_count_steps=None)

  # Passing non-None params as wrapped model_fn has it.
  if not params:
    params = {}

  super(TPUEstimator, self).__init__(
      model_fn=wrapped_model_fn,
      model_dir=model_dir,
      config=config,
      params=params,
      warm_start_from=warm_start_from)

  self._iterations_per_training_loop = (
      self._config.tpu_config.iterations_per_loop)

  # All properties passed to _InternalTPUContext are immutable.
  # pylint: disable=protected-access
  self._ctx = tpu_context._get_tpu_context(
      self._config, train_batch_size, eval_batch_size, predict_batch_size,
      use_tpu, eval_on_tpu)

  self._export_to_tpu = export_to_tpu
  self._is_input_fn_invoked = None
  self._rendezvous = {}
def _add_meta_graph_for_mode(self,
                             builder,
                             input_receiver_fn_map,
                             checkpoint_path,
                             strip_default_attrs,
                             save_variables=True,
                             mode=model_fn_lib.ModeKeys.PREDICT,
                             export_tags=None,
                             check_variables=True):
  """Adds the regular metagraph and, when exporting to TPU, a TPU one.

  The parent implementation is always called with the given arguments. When
  `export_to_tpu` was set, it is called a second time in the
  rewrite-for-inference mode with the SERVING and TPU tags, without saving
  or checking variables.

  Args:
    builder: Forwarded unchanged to the parent implementation.
    input_receiver_fn_map: Dict keyed by mode; the entry for `mode` is reused
      for the TPU metagraph.
    checkpoint_path: Forwarded unchanged to the parent implementation.
    strip_default_attrs: Forwarded unchanged to the parent implementation.
    save_variables: Forwarded for the first metagraph only.
    mode: The mode to export; must be PREDICT when exporting to TPU.
    export_tags: Forwarded for the first metagraph only.
    check_variables: Forwarded for the first metagraph only.

  Raises:
    NotImplementedError: If `export_to_tpu` is set and `mode` is not PREDICT.
  """
  if mode != model_fn_lib.ModeKeys.PREDICT and self._export_to_tpu:
    raise NotImplementedError(
        'TPUEstimator only handles mode PREDICT for exporting '
        'when `export_to_tpu` is `True`; '
        'got {}.'.format(mode))

  add_meta_graph = super(TPUEstimator, self)._add_meta_graph_for_mode
  add_meta_graph(builder,
                 input_receiver_fn_map,
                 checkpoint_path,
                 strip_default_attrs,
                 save_variables,
                 mode=mode,
                 export_tags=export_tags,
                 check_variables=check_variables)

  if not self._export_to_tpu:
    return

  tpu_receiver_fn_map = {
      _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode]
  }
  # See b/110052256 for why `check_variables` is `False`.
  add_meta_graph(builder,
                 tpu_receiver_fn_map,
                 checkpoint_path,
                 strip_default_attrs,
                 save_variables=False,
                 mode=_REWRITE_FOR_INFERENCE_MODE,
                 export_tags=[tag_constants.SERVING, tag_constants.TPU],
                 check_variables=False)
def _call_model_fn(self, features, labels, mode, config):
  """Dispatches to the inference rewrite or to the parent implementation.

  Args:
    features: Forwarded unchanged.
    labels: Forwarded unchanged.
    mode: When equal to `_REWRITE_FOR_INFERENCE_MODE` the call goes to
      `_call_model_fn_for_inference`; otherwise to the parent class.
    config: Forwarded unchanged.

  Returns:
    Whatever the selected implementation returns.
  """
  if mode != _REWRITE_FOR_INFERENCE_MODE:
    return super(TPUEstimator, self)._call_model_fn(
        features, labels, mode, config)
  return self._call_model_fn_for_inference(features, labels, mode, config)
def _call_model_fn_for_inference(self, features, labels, mode, config):
  """Wraps `_call_model_fn` for `export_savedmodel`.

  The model function is called once, in PREDICT mode, inside
  `tpu.rewrite_for_inference`. The tensors of its `export_outputs` are then
  rebuilt so that TPU tensors are replaced by the tensors returned from the
  rewrite.

  Args:
    features: Forwarded to `_call_model_fn`.
    labels: Forwarded to `_call_model_fn`.
    mode: Must be `_REWRITE_FOR_INFERENCE_MODE`.
    config: Forwarded to `_call_model_fn`.

  Returns:
    The estimator spec produced by the model function, with `export_outputs`
    replaced by the rebuilt outputs.

  Raises:
    ValueError: If `mode` is not `_REWRITE_FOR_INFERENCE_MODE`.
  """
  if mode != _REWRITE_FOR_INFERENCE_MODE:
    raise ValueError('mode must be {}; '
                     'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode))

  captured = _CapturedObject()

  def computation():
    """Compute tpu tensors used in export_outputs.

    Passed to rewrite_for_inference so that model_fn will be called under
    the rewriting contexts. Only tpu tensors are returned, but export_outputs
    and scaffold are captured.

    Returns:
       A list of Tensors used in export_outputs and not marked for
       outside_compilation.
    """
    # We should only call model fn once and it should be inside `computation`
    # so that building the graph will happen under `rewrite_for_inference`.
    spec = self._call_model_fn(
        features, labels, model_fn_lib.ModeKeys.PREDICT, config)

    # We pick the TPU tensors out from `export_output` and later return them
    # from `computation` for rewriting.
    tensors_by_key = collections.OrderedDict()
    for key, export_output in six.iteritems(spec.export_outputs):
      tensors_by_key[key] = _export_output_to_tensors(export_output)
    flat_tensors = nest.flatten(tensors_by_key)

    # We cannot return anything other than the TPU tensors here so we capture
    # the rest for later use.
    captured.capture((spec, tensors_by_key, flat_tensors))
    return [t for t in flat_tensors if _is_tpu_tensor(t)]

  tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation)
  estimator_spec, tensors_dict, tensors = captured.get()

  # Reconstruct `tensors`, but with the TPU tensors replaced with
  # `tpu_tensors_on_cpu`.
  rebuilt = []
  for tensor in tensors:
    if _is_tpu_tensor(tensor):
      rebuilt.append(tpu_tensors_on_cpu.pop(0))
      continue
    if tensor is None:
      rebuilt.append(None)
      continue
    # Only fetching `tpu_tensors_on_cpu` does not trigger
    # TPU computation and blocks, so we add the control dependency here.
    if isinstance(tpu_tensors_on_cpu, (list, tuple)):
      control_inputs = tpu_tensors_on_cpu
    else:
      control_inputs = (tpu_tensors_on_cpu,)
    with ops.control_dependencies(control_inputs):
      rebuilt.append(array_ops.identity(tensor))

  # Reconstruct `tensors_dict`.
  rebuilt_dict = nest.pack_sequence_as(tensors_dict, rebuilt)

  # Reconstruct `export_outputs`.
  original_outputs = estimator_spec.export_outputs
  rebuilt_outputs = collections.OrderedDict()
  for key, new_tensors in six.iteritems(rebuilt_dict):
    rebuilt_outputs[key] = _clone_export_output_with_tensors(
        original_outputs[key], new_tensors)
  return estimator_spec._replace(export_outputs=rebuilt_outputs)
def _create_global_step(self, graph):
  """Creates a global step suitable for TPUs.

  This is a thin wrapper around the module-level `_create_global_step`
  helper; any error it reports comes from that helper.

  Args:
    graph: The graph in which to create the global step.

  Returns:
    A global step `Tensor`.

  Raises:
    ValueError: if the global step tensor is already defined.
  """
  global_step = _create_global_step(graph)
  return global_step
def _convert_train_steps_to_hooks(self, steps, max_steps):
  """Builds the hooks that stop training after the requested steps.

  On CPU the parent implementation is used. On TPU at least one of `steps`
  and `max_steps` is required and a single `_TPUStopAtStepHook` is returned.

  Args:
    steps: Number of steps to train, or `None`.
    max_steps: Total number of steps to train, or `None`.

  Returns:
    A list of hooks.

  Raises:
    ValueError: On TPU, if both `steps` and `max_steps` are `None`.
  """
  with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx:
    if ctx.is_running_on_cpu():
      return super(TPUEstimator, self)._convert_train_steps_to_hooks(
          steps, max_steps)

  # On TPU.
  if steps is None and max_steps is None:
    raise ValueError(
        'For TPU training, one of `steps` or `max_steps` must be set. '
        'Cannot be both `None`.')

  # Estimator.train has explicit positiveness check.
  for value, label in ((steps, 'Train steps'),
                       (max_steps, 'Train max_steps')):
    if value is not None:
      util_lib.check_positive_integer(value, label)

  stop_hook = _TPUStopAtStepHook(
      self._iterations_per_training_loop, steps, max_steps)
  return [stop_hook]
def _convert_eval_steps_to_hooks(self, steps):
  """Builds the hooks that bound an evaluation run.

  On CPU the parent implementation is used. On TPU `steps` is mandatory and
  two hooks are returned: one stopping after `steps` evals and one setting
  the eval iterations.

  Args:
    steps: Number of evaluation steps; may be `None` only on CPU.

  Returns:
    A list of hooks.

  Raises:
    ValueError: On TPU, if `steps` is `None`.
  """
  with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx:
    if ctx.is_running_on_cpu():
      return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps)

  if steps is None:
    raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.')
  util_lib.check_positive_integer(steps, 'Eval steps')

  stop_hook = evaluation._StopAfterNEvalsHook(num_evals=steps)  # pylint: disable=protected-access
  iterations_hook = _SetEvalIterationsHook(steps)
  return [stop_hook, iterations_hook]
def _call_input_fn(self, input_fn, mode):
  """Calls the input function.

  Args:
    input_fn: The input function. It must accept a `params` argument and may
      optionally accept `config` and `mode`.
    mode: ModeKeys

  Returns:
    When running on CPU, the result of `input_fn`: either features or
    (features, labels) where features and labels are:
      features - `Tensor` or dictionary of string feature name to `Tensor`.
      labels - `Tensor` or dictionary of `Tensor` with labels.
    When running on TPU, a callable taking the TPU context, which stores the
    context in `params` and then calls `input_fn`.

  Raises:
    ValueError: if input_fn takes invalid arguments or does not have `params`.
  """
  input_fn_args = function_utils.fn_args(input_fn)
  if 'params' not in input_fn_args:
    raise ValueError('input_fn ({}) does not include params argument, '
                     'required by TPUEstimator to pass batch size as '
                     'params["batch_size"]'.format(input_fn))

  kwargs = {'params': self.params}  # a deep copy.
  if 'config' in input_fn_args:
    kwargs['config'] = self.config  # a deep copy.
  if 'mode' in input_fn_args:
    kwargs['mode'] = mode

  # Records the fact input_fn has been invoked.
  self._is_input_fn_invoked = True

  with self._ctx.with_mode(mode) as ctx:
    # Setting the batch size in params first. This helps user to have same
    # input_fn for use_tpu=True/False.
    batch_size_for_input_fn = ctx.batch_size_for_input_fn
    if batch_size_for_input_fn is not None:
      _add_item_to_params(kwargs['params'],
                          _BATCH_SIZE_KEY, batch_size_for_input_fn)

    # For export_savedmodel, input_fn is never passed to Estimator. So,
    # `is_export_mode` must be False.
    if ctx.is_running_on_cpu(is_export_mode=False):
      # `ops.device` is the device-scope context manager and device strings
      # use the lower-case `/device:` prefix; `ops.DEVICE` does not exist.
      with ops.device('/device:CPU:0'):
        return input_fn(**kwargs)

    # For TPU computation, input_fn should be invoked in a tf.while_loop for
    # performance. While constructing the tf.while_loop, the structure of
    # inputs returned by the `input_fn` needs to be recorded. The structure
    # includes whether features or labels is dict or single Tensor, dict keys,
    # tensor shapes, and dtypes. The recorded structure is used to create the
    # infeed dequeue ops, which must be wrapped and passed as a Fn, called
    # inside the TPU computation, as the TPU computation is wrapped inside a
    # tf.while_loop also. So, we either pass input_fn to model_fn or pass
    # dequeue_fn to model_fn. Here, `input_fn` is passed directly as
    # `features` in `model_fn` signature.
    def _input_fn(ctx):
      _add_item_to_params(kwargs['params'], _CTX_KEY, ctx)
      return input_fn(**kwargs)

    return _input_fn
def _validate_features_in_predict_input(self, result):
"""Skip the validation.
For TPUEstimator, we do not need to check the result type. `_InputPipeline`
has stronger check. Parent class's check generates confusing warning msg.
Args:
result: `features` returned by input_fn.
"""
pass
def train(self,
          input_fn,
          hooks=None,
          steps=None,
          max_steps=None,
          saving_listeners=None):
  """Runs the parent `train`, reporting failures through a rendezvous.

  A fresh `ErrorRendezvous` is registered for TRAIN mode. An exception from
  the parent call is recorded under 'training_loop' instead of propagating
  directly. In every case the loop is marked done and the rendezvous is then
  asked to raise the errors it collected.

  Args:
    input_fn: Forwarded to the parent `train`.
    hooks: Forwarded to the parent `train`.
    steps: Forwarded to the parent `train`.
    max_steps: Forwarded to the parent `train`.
    saving_listeners: Forwarded to the parent `train`.

  Returns:
    The value returned by the parent `train`.
  """
  rendezvous = error_handling.ErrorRendezvous(num_sources=3)
  self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous
  try:
    outcome = super(TPUEstimator, self).train(
        input_fn=input_fn,
        hooks=hooks,
        steps=steps,
        max_steps=max_steps,
        saving_listeners=saving_listeners)
  except Exception:  # pylint: disable=broad-except
    rendezvous.record_error('training_loop', sys.exc_info())
  else:
    return outcome
  finally:
    rendezvous.record_done('training_loop')
    rendezvous.raise_errors()
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
             name=None):
  """Runs the parent `evaluate`, reporting failures through a rendezvous.

  A fresh `ErrorRendezvous` is registered for EVAL mode. An exception from
  the parent call is recorded under 'evaluation_loop' instead of propagating
  directly. In every case the loop is marked done and the rendezvous is then
  asked to raise the errors it collected.

  Args:
    input_fn: Forwarded to the parent `evaluate`.
    steps: Forwarded to the parent `evaluate`.
    hooks: Forwarded to the parent `evaluate`.
    checkpoint_path: Forwarded to the parent `evaluate`.
    name: Forwarded to the parent `evaluate`.

  Returns:
    The value returned by the parent `evaluate`.
  """
  rendezvous = error_handling.ErrorRendezvous(num_sources=3)
  self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous
  try:
    outcome = super(TPUEstimator, self).evaluate(
        input_fn,
        steps=steps,
        hooks=hooks,
        checkpoint_path=checkpoint_path,
        name=name)
  except Exception:  # pylint: disable=broad-except
    rendezvous.record_error('evaluation_loop', sys.exc_info())
  else:
    return outcome
  finally:
    rendezvous.record_done('evaluation_loop')
    rendezvous.raise_errors()
def predict(self,
            input_fn,
            predict_keys=None,
            hooks=None,
            checkpoint_path=None,
            yield_single_examples=True):
  """Yields results of the parent `predict`, routing errors via a rendezvous.

  A fresh `ErrorRendezvous` is registered for PREDICT mode. Every item
  produced by the parent generator is yielded unchanged. An exception raised
  while iterating is recorded under 'prediction_loop'; the loop is then
  marked done and the rendezvous is asked to raise what it collected.

  Args:
    input_fn: Forwarded to the parent `predict`.
    predict_keys: Forwarded to the parent `predict`.
    hooks: Forwarded to the parent `predict`.
    checkpoint_path: Forwarded to the parent `predict`.
    yield_single_examples: Forwarded to the parent `predict`.

  Yields:
    The items produced by the parent `predict`.
  """
  rendezvous = error_handling.ErrorRendezvous(num_sources=3)
  # Stored per mode so that the hooks built in `_augment_model_fn` can look
  # it up via `self._rendezvous[mode]`.
  self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous
  try:
    for result in super(TPUEstimator, self).predict(
        input_fn=input_fn,
        predict_keys=predict_keys,
        hooks=hooks,
        checkpoint_path=checkpoint_path,
        yield_single_examples=yield_single_examples):
      yield result
  except Exception:  # pylint: disable=broad-except
    rendezvous.record_error('prediction_loop', sys.exc_info())
  finally:
    rendezvous.record_done('prediction_loop')
    rendezvous.raise_errors()

  # NOTE(review): these two calls repeat what the `finally` block above just
  # did; they look redundant -- confirm before removing.
  rendezvous.record_done('prediction_loop')
  rendezvous.raise_errors()
def _augment_model_fn(self, model_fn, train_cache_fn, eval_cache_fn, batch_axis):
  """Returns a new model_fn, which wraps the TPU support.

  Args:
    model_fn: The user model function; handed to `_ModelFnWrapper`.
    train_cache_fn: Handed to `_ModelFnWrapper` together with `model_fn`.
    eval_cache_fn: Handed to `_ModelFnWrapper` together with `model_fn`.
    batch_axis: Handed to `_InputPipeline` when running on TPU.

  Returns:
    A function with the signature `(features, labels, mode, config, params)`
    that returns an `EstimatorSpec` for CPU or TPU execution.
  """

  def _model_fn(features, labels, mode, config, params):
    """A Estimator `model_fn` for TPUEstimator."""
    with self._ctx.with_mode(mode) as ctx:
      model_fn_wrapper = _ModelFnWrapper(model_fn, train_cache_fn,
                                         eval_cache_fn, config, params, ctx)

      # `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
      # but not in `export_savedmodel()`.
      if self._is_input_fn_invoked:
        is_export_mode = False
      else:
        is_export_mode = True

      # Clear the bit.
      self._is_input_fn_invoked = None

      # examples_hook is added to training_hooks for both CPU and TPU
      # execution.
      examples_hook = ExamplesPerSecondHook(
          ctx.global_batch_size,
          output_dir=self.model_dir,
          every_n_steps=self._log_every_n_steps)

      # CPU path: call the user model_fn directly and return early.
      if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
        logging.info('Running %s on CPU', mode)
        estimator_spec = model_fn_wrapper.call_without_tpu(
            features, labels, is_export_mode=is_export_mode)
        estimator_spec = estimator_spec._replace(
            training_hooks=estimator_spec.training_hooks + (examples_hook,))
        return estimator_spec

      assert labels is None, '`labels` passed to `model_fn` must be `None`.'
      # TPUEstimator._call_input_fn passes `input_fn` as features to here.
      assert callable(features), '`input_fn` is not callable.'
      input_fn = features

      input_holders = _InputPipeline(input_fn, batch_axis, ctx)
      enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
          input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())

      # Register every enqueue op (or list of ops) in the graph collection.
      graph = ops.get_default_graph()
      for enqueue_op in enqueue_ops:
        if isinstance(enqueue_op, list):
          graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
        else:
          graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)

      if mode == model_fn_lib.ModeKeys.TRAIN:
        loss, host_call, scaffold, training_hooks = (
            _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
        # With `params['track_mean']` set, the loss is divided by the number
        # of iterations per loop.
        if model_fn_wrapper._params.get('track_mean', False):
          iterations_per_loop_var = _create_or_get_iterations_per_loop()
          loss = math_ops.div(loss,
                              math_ops.cast(
                                  iterations_per_loop_var,
                                  dtype=loss.dtype))
        host_ops = host_call.create_tpu_hostcall()
        if host_ops is None:
          host_ops = []

        # The shutdown behavior is chosen through an environment variable;
        # an empty value disables the graceful shutdown hook.
        shutdown_hooks = []
        shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
                                       'shutdown_worker')
        if shutdown_mode:
          if shutdown_mode == 'shutdown_worker':
            finalizer_hooks = [
                session_support.ShutdownLameWorkers(timeout_ms=60*1000),
            ]
          elif shutdown_mode == 'shutdown_computation':
            finalizer_hooks = [
                session_support.RestartComputation(timeout_ms=60*1000),
            ]
          else:
            raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' %
                             shutdown_mode)

          shutdown_hooks.append(session_support.GracefulShutdownHook(
              checkpoint_prefix=self.model_dir + '/model.ckpt',
              on_shutdown_hooks=finalizer_hooks
          ))

        with ops.control_dependencies([loss]):
          global_step = array_ops.identity(training.get_global_step())
        hooks = input_hooks + shutdown_hooks
        logging_hook_frequency = (  # Divide and round up
            (self._log_every_n_steps +
             self._config.tpu_config.iterations_per_loop - 1) //
            self._config.tpu_config.iterations_per_loop)

        iterations_per_loop = array_ops.identity(
            _create_or_get_iterations_per_loop())

        hooks.extend([
            TPUInfeedOutfeedSessionHook(
                ctx,
                enqueue_ops,
                host_ops,
                run_infeed_loop_on_coordinator=(
                    run_infeed_loop_on_coordinator),
                rendezvous=self._rendezvous[mode],
            ),
            InstallSignalHandlerHook(),
            # Besides the loss, log exp(loss) as 'ppl' and loss / ln(2) as
            # 'bpc'.
            training.LoggingTensorHook(
                {
                    'loss': array_ops.identity(loss),
                    'ppl': tensorflow.exp(loss),
                    'bpc': loss / tensorflow.constant(math.log(2)),
                    '#iter/loop': iterations_per_loop,
                    'global step': global_step,
                },
                every_n_iter=logging_hook_frequency)
        ])
        examples_hook._set_steps_per_run(  # pylint: disable=protected-access
            self._config.tpu_config.iterations_per_loop)
        hooks.append(examples_hook)

        if training_hooks:
          hooks.extend(training_hooks)

        # A checkpoint saver is only added when a save interval is configured.
        chief_hooks = []
        if (self._config.save_checkpoints_secs or
            self._config.save_checkpoints_steps):
          checkpoint_hook = training.CheckpointSaverHook(
              self.model_dir,
              save_secs=self._config.save_checkpoints_secs,
              save_steps=self._config.save_checkpoints_steps,
              scaffold=scaffold)
          checkpoint_hook._set_steps_per_run(  # pylint: disable=protected-access
              self._config.tpu_config.iterations_per_loop)
          chief_hooks.append(checkpoint_hook)

        summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
        with ops.control_dependencies([loss]):
          update_ops = _sync_variables_ops()

        # Validate the TPU training graph to catch basic errors
        _validate_tpu_training_graph()

        train_op = control_flow_ops.group(*update_ops)
        graph.add_to_collection(_TPU_TRAIN_OP, train_op)

        return model_fn_lib.EstimatorSpec(
            mode,
            loss=loss,
            training_chief_hooks=chief_hooks,
            training_hooks=hooks,
            train_op=train_op,
            scaffold=scaffold)

      if mode == model_fn_lib.ModeKeys.EVAL:
        total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system(
            ctx, model_fn_wrapper, dequeue_fn)
        iterations_per_loop_var = _create_or_get_iterations_per_loop()
        mean_loss = math_ops.div(total_loss,
                                 math_ops.cast(
                                     iterations_per_loop_var,
                                     dtype=total_loss.dtype))

        # Creates a dummy metric update_op for all metrics. Estimator expects
        # all metrics in eval_metric_ops have update_op and calls them one by
        # one. The real metric update_ops are invoked in a separated thread.
        # So, here give Estimator the dummy op for all metrics.
        with ops.control_dependencies([mean_loss]):
          # After TPU evaluation computation is done (the mean_loss tensor),
          # reads all variables back from TPU and updates the eval step
          # counter properly
          internal_ops_to_run = _sync_variables_ops()
          internal_ops_to_run.append(
              _increase_eval_step_op(iterations_per_loop_var))
          with ops.control_dependencies(internal_ops_to_run):
            dummy_update_op = control_flow_ops.no_op()

        host_call_ret = host_calls.create_tpu_hostcall()
        eval_metric_ops = {}
        eval_update_ops = []

        # Each metric value is indexed as (metric_tensor, update_op): the
        # tensor is exposed to Estimator, the update op goes to the hook.
        for k, v in host_call_ret.get('eval_metrics', {}).items():
          eval_metric_ops[k] = (v[0], dummy_update_op)
          eval_update_ops.append(v[1])

        if 'host_call' not in host_call_ret:
          host_ops = []
        else:
          host_ops = host_call_ret['host_call']
        hooks = [
            TPUInfeedOutfeedSessionHook(
                ctx,
                enqueue_ops,
                eval_update_ops + host_ops,
                run_infeed_loop_on_coordinator=(
                    run_infeed_loop_on_coordinator),
                rendezvous=self._rendezvous[mode]),
        ] + input_hooks

        if eval_hooks:
          hooks.extend(eval_hooks)

        return model_fn_lib.EstimatorSpec(
            mode,
            loss=mean_loss,
            evaluation_hooks=hooks,
            eval_metric_ops=eval_metric_ops,
            scaffold=scaffold)

      # Predict
      assert mode == model_fn_lib.ModeKeys.PREDICT

      (dummy_predict_op, host_calls,
       scaffold, prediction_hooks) = _predict_on_tpu_system(
           ctx, model_fn_wrapper, dequeue_fn)
      with ops.control_dependencies([dummy_predict_op]):
        internal_ops_to_run = _sync_variables_ops()
        with ops.control_dependencies(internal_ops_to_run):
          dummy_predict_op = control_flow_ops.no_op()

      # In train and evaluation, the main TPU program is passed to monitored
      # training session to run. Infeed enqueue and outfeed dequeue are
      # executed in side threads. This is not the configuration for
      # prediction mode.
      #
      # For prediction, the Estimator executes the EstimatorSpec.predictions
      # directly and yield the element (via generator) to call site. So, the
      # outfeed based prediction must be passed to MonitoredSession directly.
      # Other parts of the TPU execution are organized as follows.
      #
      # 1. All outfeed based Tensors must be grouped with predictions Tensors
      #    to form a single invocation. This avoid the issue we might trigger
      #    multiple outfeeds incorrectly. To achieve this, `host_call` is
      #    placed in control_dependencies of `stopping_signals`, and
      #    `stopping_signals` is passed into _StoppingPredictHook, which sets
      #    the `stopping_signals` as SessionRunArgs. MonitoredSession merges
      #    all SessionRunArgs with the fetch in session.run together.
      #
      # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue)
      #    are grouped together. They will be launched once and only once in
      #    side threads and they quit naturally according to the SAME stopping
      #    condition.
      enqueue_ops.append(dummy_predict_op)

      host_call_ret = host_calls.create_tpu_hostcall()
      if 'host_call' not in host_call_ret:
        host_ops = []
      else:
        host_ops = host_call_ret['host_call']

      predictions = host_call_ret['predictions']
      _verify_cross_hosts_transfer_size(
          predictions, message=(
              'The estimated size for TPUEstimatorSpec.predictions is too '
              'large.'))
      signals = host_call_ret['signals']

      with ops.control_dependencies(host_ops):
        host_ops = []  # Empty, we do not need it anymore.
        scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
            signals)
        predictions = _PaddingSignals.slice_tensor_or_dict(
            predictions, signals)

      hooks = [
          _StoppingPredictHook(scalar_stopping_signal),
          TPUInfeedOutfeedSessionHookForPrediction(
              ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]),
      ] + input_hooks

      if prediction_hooks:
        hooks.extend(prediction_hooks)

      return model_fn_lib.EstimatorSpec(
          mode,
          prediction_hooks=hooks,
          predictions=predictions,
          scaffold=scaffold)

  return _model_fn
def _is_tpu_tensor(tensor):
  """Returns True for a `Tensor` whose op lacks the outside-compilation attr.

  Args:
    tensor: any object; anything that is not an `ops.Tensor` yields False.

  Returns:
    True if `tensor` is an `ops.Tensor` and looking up
    `tpu._OUTSIDE_COMPILATION_ATTR` on its op raises `ValueError`;
    False otherwise.
  """
  if isinstance(tensor, ops.Tensor):
    try:
      tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR)  # pylint: disable=protected-access
    except ValueError:
      # The lookup failed, so the op is not marked for outside compilation.
      return True
  return False
def _export_output_to_tensors(export_output):
  """Get a list of `Tensors` used in `export_output`.

  Args:
    export_output: an `ExportOutput` object such as `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.

  Returns:
    a list of tensors used in export_output.

  Raises:
    ValueError: if `export_output` is not one of `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.
  """
  if isinstance(export_output, export_output_lib.ClassificationOutput):
    return [export_output.scores, export_output.classes]
  elif isinstance(export_output, export_output_lib.RegressionOutput):
    return [export_output.value]
  elif isinstance(export_output, export_output_lib.PredictOutput):
    # On Python 3 `dict.values()` is a view object, not a list. Materialize it
    # so that every branch really returns the documented list type.
    return list(export_output.outputs.values())
  else:
    raise ValueError(
        '`export_output` must be have type `ClassificationOutput`, '
        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
def _clone_export_output_with_tensors(export_output, tensors):
  """Builds a new export output of the same kind from `tensors`.

  Args:
    export_output: an `ExportOutput` object such as `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.
    tensors: a list of `Tensors` used to construct the new export output.
      Must have length 2 for `ClassificationOutput` and length 1 for
      `RegressionOutput`; for `PredictOutput` the tensors are paired, in
      order, with the keys of `export_output.outputs`.

  Returns:
    An object of the same export output class as `export_output`, built from
    `tensors`.

  Raises:
    ValueError: if `export_output` is not one of `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`, or if `tensors` has the wrong
      length for the classification/regression case.
  """
  if isinstance(export_output, export_output_lib.ClassificationOutput):
    if len(tensors) == 2:
      return export_output_lib.ClassificationOutput(*tensors)
    raise ValueError('tensors must be of length 2; '
                     'got {}.'.format(len(tensors)))

  if isinstance(export_output, export_output_lib.RegressionOutput):
    if len(tensors) == 1:
      return export_output_lib.RegressionOutput(*tensors)
    raise ValueError('tensors must be of length 1; '
                     'got {}'.format(len(tensors)))

  if isinstance(export_output, export_output_lib.PredictOutput):
    output_keys = export_output.outputs.keys()
    rebuilt_outputs = dict(zip(output_keys, tensors))
    return export_output_lib.PredictOutput(rebuilt_outputs)

  raise ValueError(
      '`export_output` must be have type `ClassificationOutput`, '
      '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Runs the single-step eval function repeatedly on all TPU shards.

  Args:
    ctx: context object providing `num_replicas`, `device_assignment`,
      `global_batch_size` and the TPU config.
    model_fn_wrapper: wrapper exposing `convert_to_single_tpu_eval_step` and
      an optional `_eval_cache_fn`.
    dequeue_fn: passed through to `convert_to_single_tpu_eval_step`.

  Returns:
    A tuple `(loss, host_calls, scaffold, eval_hooks)`, where `loss` is the
    first output of the sharded computation.
  """
  iterations_per_loop_var = _create_or_get_iterations_per_loop()

  eval_step_parts = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)
  (single_tpu_eval_step, host_calls, captured_scaffold_fn,
   captured_eval_hooks) = eval_step_parts

  def multi_tpu_eval_steps_on_single_shard():
    # The loop always carries the loss; cache tensors are appended only when
    # the wrapper supplies a cache function.
    initial_values = [_ZERO_LOSS]
    if model_fn_wrapper._eval_cache_fn is not None:  # pylint: disable=protected-access
      global_batch = ctx.global_batch_size
      shard_count = ctx._config._tpu_config.num_shards  # pylint: disable=protected-access
      initial_values.extend(
          model_fn_wrapper._eval_cache_fn(global_batch // shard_count))  # pylint: disable=protected-access
    return training_loop.repeat(
        iterations_per_loop_var, single_tpu_eval_step, initial_values)

  shard_outputs = tpu.shard(
      multi_tpu_eval_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)

  scaffold = _get_scaffold(captured_scaffold_fn)
  return shard_outputs[0], host_calls, scaffold, captured_eval_hooks.get()
def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Runs the single-step train function repeatedly on all TPU shards.

  Args:
    ctx: context object providing `num_replicas`, `device_assignment`,
      `global_batch_size` and the TPU config.
    model_fn_wrapper: wrapper exposing `convert_to_single_tpu_train_step`,
      `_params` and an optional `_train_cache_fn`.
    dequeue_fn: passed through to `convert_to_single_tpu_train_step`.

  Returns:
    A tuple `(loss, host_call, scaffold, training_hooks)`, where `loss` is the
    first output of the sharded computation.
  """
  iterations_per_loop_var = _create_or_get_iterations_per_loop()

  train_step_parts = model_fn_wrapper.convert_to_single_tpu_train_step(
      dequeue_fn)
  (single_tpu_train_step, host_call, captured_scaffold_fn,
   captured_training_hooks) = train_step_parts

  def multi_tpu_train_steps_on_single_shard():
    # The `track_mean` param selects which loss constant seeds the loop.
    track_mean = model_fn_wrapper._params.get('track_mean', False)  # pylint: disable=protected-access
    initial_values = [_ZERO_LOSS if track_mean else _INITIAL_LOSS]
    if model_fn_wrapper._train_cache_fn is not None:  # pylint: disable=protected-access
      global_batch = ctx.global_batch_size
      shard_count = ctx._config._tpu_config.num_shards  # pylint: disable=protected-access
      initial_values.extend(
          model_fn_wrapper._train_cache_fn(global_batch // shard_count))  # pylint: disable=protected-access
    return training_loop.repeat(
        iterations_per_loop_var, single_tpu_train_step, initial_values)

  shard_outputs = tpu.shard(
      multi_tpu_train_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)

  scaffold = _get_scaffold(captured_scaffold_fn)
  return shard_outputs[0], host_call, scaffold, captured_training_hooks.get()
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Runs the single-step predict function in a loop on all TPU shards.

  The loop keeps iterating until the scalar stopping signal returned by the
  predict step indicates that it should stop.

  Args:
    ctx: context object providing `num_replicas` and `device_assignment`.
    model_fn_wrapper: wrapper exposing `convert_to_single_tpu_predict_step`.
    dequeue_fn: passed through to `convert_to_single_tpu_predict_step`.

  Returns:
    A tuple `(dummy_predict_op, host_calls, scaffold, predict_hooks)`.
  """
  predict_step_parts = model_fn_wrapper.convert_to_single_tpu_predict_step(
      dequeue_fn)
  (single_tpu_predict_step, host_calls, captured_scaffold_fn,
   captured_predict_hooks) = predict_step_parts

  def multi_tpu_predict_steps_on_single_shard():

    def cond(scalar_stopping_signal):
      stop_requested = _StopSignals.should_stop(scalar_stopping_signal)
      return math_ops.logical_not(stop_requested)

    # Start from the non-stopping value so the loop is entered.
    return training_loop.while_loop(
        cond,
        single_tpu_predict_step,
        inputs=[_StopSignals.NON_STOPPING_SIGNAL],
        name=b'loop')

  shard_outputs = tpu.shard(
      multi_tpu_predict_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)
  (dummy_predict_op,) = shard_outputs

  scaffold = _get_scaffold(captured_scaffold_fn)
  return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get()
def _wrap_computation_in_while_loop(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop.

  Args:
    device: device specification under which the loop is built.
    op_fn: callable taking no arguments and returning the ops that must run
      on every loop iteration.

  Returns:
    The result of `control_flow_ops.while_loop`, whose loop variable counts
    from 0 up to the iterations-per-loop value.
  """

  def computation(i):
    # Make the counter increment depend on the ops from `op_fn` so they are
    # executed on each iteration.
    with ops.control_dependencies(op_fn()):
      return i + 1

  iterations_per_loop_var = _create_or_get_iterations_per_loop()
  # By setting parallel_iterations=1, the parallel execution in while_loop is
  # basically turned off.
  # NOTE: the device scope is `ops.device`; `ops.DEVICE` is not an attribute
  # of the TensorFlow `framework.ops` module.
  with ops.device(device):
    iterations = array_ops.identity(iterations_per_loop_var)
    return control_flow_ops.while_loop(
        lambda i: i < iterations,
        computation, [constant_op.constant(0)],
        parallel_iterations=1)
def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop.

  The loop runs until the scalar stopping signal derived from the `signals`
  returned by `op_fn` indicates stopping.

  Args:
    device: device specification under which the loop is built.
    op_fn: callable taking no arguments and returning a dict with the keys
      `'ops'` (ops to execute each iteration) and `'signals'`.

  Returns:
    The result of `control_flow_ops.while_loop`.
  """

  def cond(scalar_stopping_signal):
    return math_ops.logical_not(
        _StopSignals.should_stop(scalar_stopping_signal))

  def computation(unused_scalar_stopping_signal):
    return_value = op_fn()
    execute_ops = return_value['ops']
    signals = return_value['signals']
    # The next stopping signal is only produced after the ops have run.
    with ops.control_dependencies(execute_ops):
      return _StopSignals.as_scalar_stopping_signal(signals)

  # By setting parallel_iterations=1, the parallel execution in while_loop is
  # basically turned off.
  # NOTE: the device scope is `ops.device`; `ops.DEVICE` is not an attribute
  # of the TensorFlow `framework.ops` module.
  with ops.device(device):
    return control_flow_ops.while_loop(
        cond,
        computation, [_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)
def _validate_tpu_training_graph():
  """Checks the default graph for at least one cross-replica sum op.

  Such an op should be introduced by using the CrossShardOptimizer wrapper.

  Raises:
    ValueError: If the default graph contains no operation whose type is
      `_CROSS_REPLICA_SUM_OP`.
  """
  graph_operations = ops.get_default_graph().get_operations()

  has_cross_replica_sum = any(
      operation.type == _CROSS_REPLICA_SUM_OP
      for operation in graph_operations)
  if has_cross_replica_sum:
    return

  raise ValueError(
      'CrossShardOptimizer must be used for model training on TPUs.')
class _CapturedObject(object):
  """A write-once holder for a Python object.

  This is useful when we need to capture a Python object in the Tensorflow
  control flow body function and use it outside the control flow.
  """

  def __init__(self):
    self._object = None
    self._captured = False

  def capture(self, o):
    """Stores `o`; raises `RuntimeError` if something was already stored."""
    if not self._captured:
      self._captured = True
      self._object = o
      return
    raise RuntimeError(
        'InternalError: Object can capture only once. Please file bug.')

  def get(self):
    """Returns the stored object; raises `RuntimeError` if none was stored."""
    if self._captured:
      return self._object
    raise RuntimeError(
        'InternalError: Object is not captured properly before `get`. '
        'Please file bug.')
def _get_scaffold(captured_scaffold_fn):
  """Retrieves the Scaffold from `captured_scaffold_fn`.

  Args:
    captured_scaffold_fn: object whose `get()` returns the scaffold function,
      or a falsy value when there is none.

  Returns:
    The scaffold produced by the scaffold function, with its `finalize`
    wrapped so it runs inside a `_CapturingContext`; or None when no scaffold
    function was captured.

  Raises:
    ValueError: if the scaffold function returns None.
  """
  with _CapturingContext(message='Inside scaffold_fn'):
    scaffold_fn = captured_scaffold_fn.get()
    scaffold = scaffold_fn() if scaffold_fn else None
    if scaffold_fn and scaffold is None:
      raise ValueError(
          'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')

  if not scaffold:
    return scaffold

  original_finalize = scaffold.finalize

  def _finalize():
    # Run the original finalize under the capturing context as well.
    with _CapturingContext('Inside Scaffold.finalize'):
      original_finalize()

  scaffold.finalize = _finalize
  return scaffold
class _CapturingContext(control_flow_ops.ControlFlowContext):
  """Control flow context rejecting ops that consume TPU-replicated tensors.

  While active (used as a context manager), it is installed as the default
  graph's control flow context and restored on exit.
  """

  def __init__(self, message):
    super(_CapturingContext, self).__init__()
    # Prefix used in the error raised by `AddOp`.
    self._message = message

  def AddOp(self, op):  # pylint: disable=invalid-name
    """Raises `ValueError` if any input of `op` comes from a TPU-replicated op."""
    replicate_attr = tpu._TPU_REPLICATE_ATTR  # pylint: disable=protected-access
    for c in op.inputs:
      if replicate_attr not in c.op.node_def.attr:
        continue
      raise ValueError('{}: Op {} depends on TPU computation {}, '
                       'which is not allowed.'.format(self._message, op, c))

  def to_control_flow_context_def(self, context_def, export_scope=None):
    # pylint: disable=useless-super-delegation
    # NOTE(slebedev): the method is required by `ControlFlowContext`.
    super(_CapturingContext, self).to_control_flow_context_def(
        context_def, export_scope)

  def __enter__(self):
    # pylint: disable=protected-access
    graph = ops.get_default_graph()
    self._g = graph
    # Remember the previous context so `__exit__` can restore it.
    self._old = graph._get_control_flow_context()
    graph._set_control_flow_context(self)
    # pylint: enable=protected-access

  def __exit__(self, exc_type, exc_value, traceback):
    self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access
class _Inputs(object):
  """Holds what `input_fn` returned: either tensors or a `Dataset`.

  Exactly one of the two forms may be given: a dataset, or any combination of
  features, labels and signals.
  """

  def __init__(self, features=None, labels=None, dataset=None, signals=None):
    has_tensor_inputs = any(
        item is not None for item in (features, labels, signals))
    if dataset is not None and has_tensor_inputs:
      raise RuntimeError('Internal Error: Either (features and labels) or '
                         'dataset should be provided, not both. Please file '
                         'bug')

    self._features = features
    self._labels = labels
    self._signals = signals

    self._dataset = dataset
    # Populated by `dataset_initializer_hook` for the dataset form.
    self._iterator = None

  @staticmethod
  def from_input_fn(return_values):
    """Returns an `_Inputs` instance according to `input_fn` return value."""
    if not isinstance(return_values, dataset_ops.Dataset):
      features, labels = _Inputs._parse_inputs(return_values)
      return _Inputs(features, labels)
    return _Inputs(dataset=return_values)

  @staticmethod
  def _parse_inputs(return_values):
    """Splits a tuple into (features, labels); anything else has no labels."""
    if not isinstance(return_values, tuple):
      return return_values, None
    features, labels = return_values
    return features, labels

  @property
  def is_dataset(self):
    """Returns True if the return value from input_fn is Dataset."""
    return self._dataset is not None

  def dataset_initializer_hook(self):
    """Returns a `SessionRunHook` to initialize this dataset.

    This must be called before `features_and_labels`.
    """
    iterator = self._dataset.make_initializable_iterator()
    # pylint: disable=protected-access
    hook = estimator_util._DatasetInitializerHook(iterator)
    # pylint: enable=protected-access
    self._iterator = iterator
    return hook

  def features_and_labels(self):
    """Gets `features` and `labels`."""
    if not self.is_dataset:
      return (self._features, self._labels)

    if self._iterator is None:
      raise RuntimeError('Internal error: Must call dataset_initializer_hook '
                         'before calling features_and_labels(). Please file '
                         'a bug!')
    return _Inputs._parse_inputs(self._iterator.get_next())

  def signals(self):
    """Returns the signals given at construction time (may be None)."""
    return self._signals

  @property
  def dataset(self):
    """Returns the dataset given at construction time (may be None)."""
    return self._dataset
class _InputsWithStoppingSignals(_Inputs):
  """Inputs with `_StopSignals` inserted into the dataset."""

  def __init__(self,
               dataset,
               batch_size,
               add_padding=False,
               num_invocations_per_step=1):
    assert dataset is not None

    def _signal_map_fn(stop):
      return _InputsWithStoppingSignals.insert_stopping_signal(
          stop=stop, batch_size=batch_size, add_padding=add_padding)

    user_provided_dataset = dataset.map(_signal_map_fn(False))

    # A single batch taken from the user dataset, marked as stopping, acts as
    # the terminator in both configurations.
    final_batch_dataset = dataset.take(1).map(_signal_map_fn(True))

    if num_invocations_per_step != 1:
      # We append (2 * num_invocations_per_step - 1) batches for exhausting the
      # user_provided_dataset and stop properly.
      # For example, if num_invocations_per_step is 2, we append 3 additional
      # padding batches: b1, b2, b3.
      # If user_provided_dataset contains two batches: a1, a2
      #   Step 1: [a1, a2]
      #   Step 2: [b1, b2] -> STOP
      # If user_provided_dataset contains three batches: a1, a2, a3.
      # The training loops:
      #   Step 1: [a1, a2]
      #   Step 2: [a3, b1]
      #   Step 3: [b2, b3] -> STOP.
      final_batch_dataset = final_batch_dataset.repeat(
          2 * num_invocations_per_step - 1)

      def _set_mask(data_dict):
        signals = data_dict['signals']
        signals['padding_mask'] = array_ops.ones_like(signals['padding_mask'])
        data_dict['signals'] = signals
        return data_dict

      # Mask out the extra batch.
      final_batch_dataset = final_batch_dataset.map(_set_mask)

    combined_dataset = user_provided_dataset.concatenate(final_batch_dataset)
    dataset = combined_dataset.prefetch(2)

    super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)
    # Element fetched by `features_and_labels` and not yet read by `signals`.
    self._current_inputs = None

  def features_and_labels(self):
    """Fetches the next element and returns its `features` and `labels`.

    Must be followed by a call to `signals` before being called again.
    """
    if self._current_inputs is not None:
      raise RuntimeError(
          'Internal Error: The previous inputs have not been properly '
          'consumed. First call features_and_labels, then call signals.')

    next_element = self._iterator.get_next()
    self._current_inputs = next_element
    return next_element['features'], next_element.get('labels')

  def signals(self):
    """Returns the `Signals` from `_Inputs`."""
    pending = self._current_inputs
    if pending is None:
      raise RuntimeError(
          'Internal Error: The current inputs have not been properly '
          'generated. First call features_and_labels, then call signals.')
    # Mark the element as consumed so the next fetch is allowed.
    self._current_inputs = None
    return pending['signals']

  @staticmethod
  def insert_stopping_signal(stop, batch_size, add_padding=False):
    """Inserts stopping_signal into dataset via _map_fn.

    Here we change the data structure in the dataset, such that the return value
    is a dictionary now and `features`, `labels`, and `signals` are three
    distinguished keys in that dict. This provides a better structure, which
    eases the process to decompose the inputs (see `features_and_labels`).

    Args:
      stop: bool, state of current stopping signals.
      batch_size: int, batch size.
      add_padding: bool, whether to pad the tensor to full batch size.

    Returns:
      A map_fn passed to dataset.map API.
    """

    def _map_fn(*args):
      """The map fn to insert signals."""
      if len(args) == 1:
        # Unpack the single Tensor/dict argument as features. This is required
        # for the input_fn returns no labels.
        args = args[0]
      features, labels = _Inputs._parse_inputs(args)

      padding_mask = None
      if add_padding:
        padding_mask, features, labels = (
            _PaddingSignals.pad_features_and_labels(
                features, labels, batch_size))

      new_input_dict = {'features': features}
      if labels is not None:
        new_input_dict['labels'] = labels
      new_input_dict['signals'] = _StopSignals(
          stop=stop, batch_size=batch_size, padding_mask=padding_mask).as_dict()
      return new_input_dict

    return _map_fn
class _StopSignals(object):
  """Signals class holding all logic to handle TPU stopping condition."""

  NON_STOPPING_SIGNAL = False
  STOPPING_SIGNAL = True

  def __init__(self, stop, batch_size, padding_mask=None):
    self._stop = stop
    self._batch_size = batch_size
    self._padding_mask = padding_mask

  def as_dict(self):
    """Returns the signals as Python dict.

    The dict always has a `'stopping'` entry: a bool tensor of shape
    `[batch_size, 1]` filled with ones when stopping, zeros otherwise. A
    `'padding_mask'` entry is added only when a padding mask was supplied.
    """
    make_filled = array_ops.ones if self._stop else array_ops.zeros
    stopping = make_filled(shape=[self._batch_size, 1], dtype=dtypes.bool)

    signals = {'stopping': stopping}
    if self._padding_mask is not None:
      signals['padding_mask'] = self._padding_mask
    return signals

  @staticmethod
  def as_scalar_stopping_signal(signals):
    """Returns the first element of `signals['stopping']` via an identity op."""
    first_element = signals['stopping'][0][0]
    return array_ops.identity(first_element)

  @staticmethod
  def should_stop(scalar_stopping_signal):
    """Detects whether scalar_stopping_signal indicates stopping."""
    if not isinstance(scalar_stopping_signal, ops.Tensor):
      # For non Tensor case, it is used in SessionRunHook. So, we cannot modify
      # the graph anymore. Here, we use pure Python.
      return bool(scalar_stopping_signal)

    # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF
    # way to express the bool check whether scalar_stopping_signal is True.
    return math_ops.logical_and(
        scalar_stopping_signal, _StopSignals.STOPPING_SIGNAL)
class _PaddingSignals(object):
  """Signals class holding all logic to handle padding."""

  @staticmethod
  def pad_features_and_labels(features, labels, batch_size):
    """Pads out the batch dimension of features and labels.

    Args:
      features: a Tensor or nested structure of Tensors; the leading dimension
        of the first Tensor found defines the real batch size.
      labels: a Tensor or nested structure of Tensors, or None.
      batch_size: int, the batch size to pad up to.

    Returns:
      A tuple `(padding_mask, features, labels)` with the padded structures;
      `labels` stays None when it was None.
    """
    reference_tensor = _PaddingSignals._find_any_tensor(features)
    real_batch_size = array_ops.shape(reference_tensor)[0]

    batch_size_tensor = constant_op.constant(batch_size, dtypes.int32)

    check_greater = check_ops.assert_greater_equal(
        batch_size_tensor, real_batch_size,
        data=(batch_size_tensor, real_batch_size),
        message='The real batch size should not be greater than batch_size.')

    # The number of rows to add is only computed once the check has run.
    with ops.control_dependencies([check_greater]):
      missing_count = batch_size_tensor - real_batch_size

    def pad_single_tensor(tensor):
      """Pads out the batch dimension of a tensor to the complete batch_size."""
      rank = len(tensor.shape)
      assert rank > 0
      # Pad only at the end of the first dimension; leave the others untouched.
      leading_padding = [[0, missing_count]]
      trailing_padding = [[0, 0]] * (rank - 1)
      padding = array_ops.stack(leading_padding + trailing_padding)
      padded_shape = (batch_size,) + tuple(tensor.shape[1:])
      padded_tensor = array_ops.pad(tensor, padding)
      padded_tensor.set_shape(padded_shape)
      return padded_tensor

    features = nest.map_structure(pad_single_tensor, features)
    if labels is not None:
      labels = nest.map_structure(pad_single_tensor, labels)

    padding_mask = _PaddingSignals._padding_mask(
        real_batch_size, missing_count, batch_size)

    return padding_mask, features, labels

  @staticmethod
  def slice_tensor_or_dict(tensor_or_dict, signals):
    """Slice the real Tensors according to padding mask in signals.

    Args:
      tensor_or_dict: a Tensor or nested structure of Tensors to slice.
      signals: dict with a `'padding_mask'` entry and the stopping signal.

    Returns:
      A structure like `tensor_or_dict` where each Tensor is either passed
      through (full batch or stopping batch) or cut to the real batch size.
    """
    padding_mask = signals['padding_mask']
    batch_size = array_ops.shape(padding_mask)[0]

    def verify_batch_size(tensor):
      # NOTE(review): `math_ops.equal` only computes a boolean; it does not
      # appear to raise on mismatch - confirm whether an assert was intended.
      check_batch_size = math_ops.equal(batch_size, tensor.shape[0])
      with ops.control_dependencies([check_batch_size]):
        return array_ops.identity(tensor)

    def slice_single_tensor(tensor):
      rank = len(tensor.shape)
      assert rank > 0
      padded_count = math_ops.reduce_sum(padding_mask)
      real_batch_size = batch_size - padded_count
      return verify_batch_size(tensor)[0:real_batch_size]

    # As we split the Tensors to all TPU cores and concat them back, it is
    # important to ensure the real data is placed before padded ones, i.e.,
    # order is preserved. By that, the sliced padding mask should have all 0's.
    # If this assertion failed, the slice logic here would not hold.
    sliced_padding_mask = slice_single_tensor(padding_mask)
    assert_padding_mask = math_ops.equal(
        math_ops.reduce_sum(sliced_padding_mask), 0)

    with ops.control_dependencies([assert_padding_mask]):
      scalar_signal = _StopSignals.as_scalar_stopping_signal(signals)
      should_stop = _StopSignals.should_stop(scalar_signal)

    is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0)

    def slice_fn(tensor):
      # If the current batch is full batch or part of stopping signals, we do
      # not need to slice to save performance.
      skip_slicing = math_ops.logical_or(should_stop, is_full_batch)
      return control_flow_ops.cond(
          skip_slicing,
          (lambda: verify_batch_size(tensor)),
          (lambda: slice_single_tensor(tensor)))

    return nest.map_structure(slice_fn, tensor_or_dict)

  @staticmethod
  def _find_any_tensor(batch_features):
    """Returns the first `ops.Tensor` in the flattened `batch_features`."""
    for candidate in nest.flatten(batch_features):
      if isinstance(candidate, ops.Tensor):
        return candidate
    raise ValueError('Cannot find any Tensor in features dict.')

  @staticmethod
  def _padding_mask(real_batch_size, missing_count, batch_size):
    """Builds an int32 mask: zeros for real rows followed by ones for padding."""
    real_part = array_ops.zeros((real_batch_size,), dtype=dtypes.int32)
    padded_part = array_ops.ones((missing_count,), dtype=dtypes.int32)
    padding_mask = array_ops.concat([real_part, padded_part], axis=0)
    padding_mask.set_shape((batch_size,))
    return padding_mask
def _verify_cross_hosts_transfer_size(tensor_dict, message):
  """Raises if the estimated byte size of `tensor_dict` is too large.

  The size of each tensor is estimated as the product of its static shape
  times the byte size of its dtype; the sizes are summed over all entries.

  Args:
    tensor_dict: dict mapping keys to tensors exposing `shape` and
      `dtype.size`.
    message: str, prefix for the error message.

  Raises:
    ValueError: if the total estimated size is at least `_ONE_GIGABYTE`.
  """
  total_size = 0
  tensor_structure = {}
  for key, tensor in tensor_dict.items():
    shape = tensor.shape
    # `np.prod` replaces `np.product`, a deprecated alias of the same function
    # that was removed in NumPy 2.0.
    size = np.prod(shape) * tensor.dtype.size
    tensor_structure[key] = shape
    total_size += size

  if total_size >= _ONE_GIGABYTE:
    raise ValueError(
        '{} The transfer size is larger than the protobuf limit. Please '
        'consider to use Tensors with smaller shapes or reduce batch '
        'size. Given:\n'
        '{}'.format(message, '\n'.join([
            ' -- Key: {}, Shape: {}'.format(k, v)
            for k, v in tensor_structure.items()])))
def _add_item_to_params(params, key, value):
  """Sets `key` to `value` in `params`, which is an `HParams` or a dict.

  Args:
    params: an `hparam.HParams` instance or a Python dict; modified in place.
    key: name of the item to add or overwrite.
    value: value to store under `key`.
  """
  if not isinstance(params, hparam.HParams):
    # Now params is Python dict.
    params[key] = value
    return

  # For HParams, we need to use special API: updating an existing entry and
  # adding a new one go through different methods.
  update_fn = params.set_hparam if key in params else params.add_hparam
  update_fn(key, value)
def export_estimator_savedmodel(estimator,
                                export_dir_base,
                                serving_input_receiver_fn,
                                assets_extra=None,
                                as_text=False,
                                checkpoint_path=None,
                                strip_default_attrs=False):
  """Export `Estimator` trained model for TPU inference.

  A `TPUEstimator` is built around the model function, params and model
  directory of `estimator`, and the export is delegated to it.

  Args:
    estimator: `Estimator` with which model has been trained.
    export_dir_base: A string containing a directory in which to create
      timestamped subdirectories containing exported SavedModels.
    serving_input_receiver_fn: A function that takes no argument and
      returns a `ServingInputReceiver` or `TensorServingInputReceiver`.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel, or `None` if no extra assets are needed.
    as_text: whether to write the SavedModel proto in text format.
    checkpoint_path: The checkpoint path to export.  If `None` (the default),
      the most recent checkpoint found within the model directory is chosen.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs.

  Returns:
    The string path to the exported directory.
  """
  # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use
  # `estimator.config`.
  tpu_run_config = tpu_config.RunConfig(model_dir=estimator.model_dir)

  # The batch sizes are required by the constructor, but their values do not
  # matter for export.
  placeholder_batch_size = 2048
  tpu_estimator = TPUEstimator(
      estimator._model_fn,  # pylint: disable=protected-access
      config=tpu_run_config,
      params=estimator.params,
      use_tpu=True,
      train_batch_size=placeholder_batch_size,
      eval_batch_size=placeholder_batch_size,
  )

  export_args = (
      export_dir_base,
      serving_input_receiver_fn,
      assets_extra,
      as_text,
      checkpoint_path,
      strip_default_attrs,
  )
  return tpu_estimator.export_savedmodel(*export_args)