TensorFlow Serving C++ API Documentation
batching_session.h
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A library for wrapping a tensorflow session such that Run() calls get
17 // scheduled in batches, using a batch scheduler of your choice.
18 
19 #ifndef TENSORFLOW_SERVING_BATCHING_BATCHING_SESSION_H_
20 #define TENSORFLOW_SERVING_BATCHING_BATCHING_SESSION_H_
21 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow_serving/batching/batching_options.h"
#include "tensorflow_serving/batching/threadsafe_status.h"
38 
39 namespace tensorflow {
40 namespace serving {
41 
// Task type handled by the batch schedulers that back a batching session.
// Use it as the template argument wherever a scheduler type is spelled out,
// e.g. BasicBatchScheduler<BatchingSessionTask>. The full definition is given
// further down in this header.
struct BatchingSessionTask;
45 
46 // A function to construct a batch scheduler for BatchingSessionTasks from a
47 // process-batch callback.
48 using BatchingSessionSchedulerCreator = std::function<Status(
49  std::function<void(std::unique_ptr<Batch<BatchingSessionTask>>)>,
50  std::unique_ptr<BatchScheduler<BatchingSessionTask>>*)>;
51 
52 // The signature associated with a Session::Run() call, in terms of input and
53 // output tensor names (with the order in which the tensors are listed factored
54 // out). (Note that 'target_node_names' are not supported in batching sessions.)
56  std::set<string> input_tensors;
57  std::set<string> output_tensors;
58 };
59 
60 // Constructs a TensorSignature for a given SignatureDef.
61 TensorSignature TensorSignatureFromSignatureDef(
62  const SignatureDef& signature_def);
63 
64 // Constructs a TensorSignature for a given set of SignatureDefs. The resulting
65 // TensorSignature represents the Session::Run() arguments that would be used
66 // when issuing a single Run() call that exercises the signature defs jointly.
67 //
68 // For example, say there's a graph that takes 'input' and transforms it into
69 // 'predicted_label' and 'confidence_score'. Suppose SignatureDef 1 requests
70 // only 'predicted_label' as output, and SignatureDef 2 requests only
71 // 'confidence_score'. A joint TensorSignature would feed 'input' and receive
72 // both 'predicted_label' and 'confidence_score' as output, in a single Run()
73 // invocation.
74 TensorSignature TensorSignatureFromSignatureDefs(
75  const std::vector<SignatureDef>& signature_defs);
76 
77 // A signature paired with a lambda to create a batch scheduler for Run() calls
78 // matching the signature.
80  TensorSignature signature;
81  BatchingSessionSchedulerCreator scheduler_creator;
82 };
83 
84 // Options for batching tensorflow Sessions; see the Create*() functions below.
86 
87 // Wraps a session in a new session that automatically batches Run() calls.
88 // Uses one batcher for each distinct Run() signature supported. In addition to
89 // a session to wrap, takes a list of signature/BatchingSessionSchedulerCreator
90 // pairs. (The number of supported signatures is typically small, and often just
91 // a single one.)
92 //
93 // The wrapped session only batches Run() calls that conform to one of the
94 // specified signatures and leave 'target_node_names' empty. Other Run() calls
95 // are executed in-line without batching, and may harm performance. (Extra-
96 // signature Run() support is intended primarily for debugging and diagnostics.)
97 //
98 // For batched calls, it is assumed that the outermost (0th) dimension of each
99 // input and output tensor is the batch-size dimension. All input tensors must
100 // have the same 0th-dimension size B; the produced output tensors are also
101 // assumed to have 0th-dimension size B.
102 //
103 // IMPORTANT: Each call to Session::Run() is synchronous, and blocks waiting for
104 // other Run() calls with the same signature to merge with to form a large
105 // batch. Consequently, to achieve good throughput we recommend setting the
106 // number of client threads that call Session::Run() equal to about twice the
107 // sum over all signatures of the maximum batch size.
108 //
109 // Example usage, for the common case of a single signature:
110 //
111 // BatchingSessionOptions options = ...;
112 // auto scheduler_creator = [schedule_options, retry_options](
113 // std::function<void(std::unique_ptr<Batch<BatchingSessionTask>>)>
114 // process_batch_callback,
115 // std::unique_ptr<BatchScheduler<BatchingSessionTask>>* batch_scheduler) {
116 // std::unique_ptr<BasicBatchScheduler<BatchingSessionTask>> scheduler;
117 // TF_RETURN_IF_ERROR(BasicBatchScheduler<BatchingSessionTask>::Create(
118 // schedule_options, process_batch_callback, &scheduler));
119 // std::unique_ptr<BatchSchedulerRetrier<BatchingSessionTask>> retrier;
120 // TF_RETURN_IF_ERROR(BatchSchedulerRetrier<BatchingSessionTask>::Create(
121 // retry_options, std::move(scheduler), &retrier));
122 // *batch_scheduler = std::move(retrier);
123 // return Status::OK();
124 // };
125 // std::unique_ptr<Session> batching_session;
126 // TF_CHECK_OK(CreateBatchingSession(options, {{signature, scheduler_creator}},
127 // std::move(session), &batching_session));
128 //
129 Status CreateBatchingSession(
130  const BatchingSessionOptions& options,
131  const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
132  signatures_with_scheduler_creators,
133  std::unique_ptr<Session> session,
134  std::unique_ptr<Session>* batching_session);
135 
136 // Same as above but allows for a default scheduler creator for which signatures
137 // that don't match a supplied value during run time can still use batching.
138 Status CreateBatchingSession(
139  const BatchingSessionOptions& options,
140  const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
141  signatures_with_scheduler_creators,
142  BatchingSessionSchedulerCreator default_creator,
143  std::unique_ptr<Session> session,
144  std::unique_ptr<Session>* batching_session);
145 
146 // A convenience for using CreateBatchingSession() to create a
147 // BasicBatchScheduler for a single signature.
148 Status CreateBasicBatchingSession(
149  const typename BasicBatchScheduler<BatchingSessionTask>::Options&
150  schedule_options,
151  const BatchingSessionOptions& batching_session_options,
152  const TensorSignature& signature, std::unique_ptr<Session> session,
153  std::unique_ptr<Session>* batching_session);
154 
155 // The default implementation of
156 // `BasicBatchScheduler::Options.split_input_task_func` if corresponding batch
157 // scheduler for a batching session sets
158 // `BasicBatchScheduler::Options.enable_large_batch_splitting` to true.
159 Status SplitInputTask(
160  std::unique_ptr<BatchingSessionTask>* input_task_ptr,
161  int open_batch_remaining_slot, int max_batch_size,
162  std::vector<std::unique_ptr<BatchingSessionTask>>* output_tasks);
163 
165 // Implementation details follow. API users need not read.
166 
167 struct BatchingSessionTask : public BatchTask {
168  ~BatchingSessionTask() override = default;
169  size_t size() const override { return zeroth_dim_size; }
170 
171  // For monitoring purpose.
172  static std::string Name() { return "batching_session"; }
173 
174  // Fields populated when a task is received.
175  uint64_t enqueue_time_micros;
176  RunOptions run_options;
177  size_t zeroth_dim_size;
178  const std::vector<std::pair<string, Tensor>>* inputs;
179  const std::vector<string>* output_tensor_names;
180 
181  // Fields populated when a task is processed (as part of a batch), and
182  // returned by BatchingSession when a task is complete.
183  Notification* done;
184  Status* status;
185  std::vector<Tensor>* outputs;
186  RunMetadata* run_metadata;
187  absl::optional<thread::ThreadPoolOptions> thread_pool_options;
188 
189  // Fields populated when a task is processed (as part of a batch), and
190  // substantially used in the intermediate stage if a task is a slice of
191  // input task (i.e., is_partial=true).
192  bool is_partial = false;
193  // 'owned_split_inputs' stores pairs of tensor names and input tensors
194  // if 'is_partial' = true.
195  std::unique_ptr<std::vector<std::pair<string, Tensor>>> owned_split_inputs;
196  // The index of this split, along the 0-th dimension of input from op
197  // invocation.
198  int split_index = 0;
199  std::function<void()> done_callback;
200  typedef std::vector<std::vector<Tensor>> TensorMatrix;
201  // For shared_ptr objects, ownership shared by:
202  // 1) each split of task (to fill one row in this matrix)
203  // and
204  // 2) callback that runs to merge output of individual splits for an op
205  // invocation, after all splits complete.
206  // Two-dimensional tensor matrix,
207  std::shared_ptr<TensorMatrix> shared_outputs;
208  // 'status' records error (could be from any split) if at least one split
209  // returns error, OK otherwise.
210  std::shared_ptr<ThreadSafeStatus> thread_safe_status;
211  // 'split_run_metadatas' records `run_metadata` of each split.
212  std::shared_ptr<std::vector<RunMetadata>> split_run_metadatas;
213 };
214 
215 } // namespace serving
216 } // namespace tensorflow
217 
218 #endif // TENSORFLOW_SERVING_BATCHING_BATCHING_SESSION_H_