TensorFlow Serving C++ API Documentation
tflite_session.h
/* Copyright 2019 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFLITE_SESSION_H_
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFLITE_SESSION_H_

#include <map>
#include <memory>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/model.h"
#include "tensorflow_serving/batching/threadsafe_status.h"
#include "tensorflow_serving/servables/tensorflow/serving_session.h"
#include "tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h"

namespace tensorflow {
namespace serving {

using TensorInfoMap = std::map<string, std::pair<TensorInfo, int>>;

// Encapsulates a unit of work for the BatchScheduler.
class TfLiteBatchTask : public BatchTask {
 public:
  // Creates a batch task.
  static void CreateTfLiteBatchTask(
      const std::vector<string>* output_tensor_names,
      std::vector<Tensor>* outputs, Notification* done, Status* status,
      std::unique_ptr<TfLiteBatchTask>* batch_task) {
    TfLiteBatchTask* task = new TfLiteBatchTask();
    task->is_partial = false;
    task->output_tensor_names = output_tensor_names;
    task->outputs = outputs;
    task->done = done;
    task->status = status;
    batch_task->reset(task);
  }

  // Creates a partial batch task, used when a large request is split across
  // batches.
  static void CreatePartialTfLiteBatchTask(
      std::vector<int> input_indices,
      const std::vector<string>* output_tensor_names,
      std::vector<Tensor>* outputs, std::function<void()> done_callback,
      ThreadSafeStatus* partial_status,
      std::unique_ptr<TfLiteBatchTask>* batch_task) {
    TfLiteBatchTask* task = new TfLiteBatchTask();
    task->is_partial = true;
    task->input_indices = input_indices;
    task->output_tensor_names = output_tensor_names;
    task->outputs = outputs;
    task->done_callback = done_callback;
    task->partial_status = partial_status;
    batch_task->reset(task);
  }

  TfLiteBatchTask() : enqueue_time_micros(Env::Default()->NowMicros()) {}

  TfLiteBatchTask(const TfLiteBatchTask&) = delete;

  TfLiteBatchTask& operator=(const TfLiteBatchTask&) = delete;

  ~TfLiteBatchTask() override = default;

  // Returns the batch size, i.e. the size of dimension 0 of the first input
  // tensor.
  size_t size() const override { return inputs[0].dim_size(0); }

  uint64_t start_time_micros() const { return enqueue_time_micros; }

  Notification* done;

  Status* status;

  // Indices of the TFLite input tensors, aligned with `inputs`.
  std::vector<int> input_indices;

  // Vector of input tensors.
  std::vector<Tensor> inputs;

  // Pointer to the vector that receives the output tensors.
  std::vector<Tensor>* outputs;

  void set_output(Tensor t) { outputs->push_back(t); }

  const std::vector<string>* output_tensor_names;

  RunOptions run_options;

  const uint64_t enqueue_time_micros;

  // True when this task holds part of a request that was split across
  // batches.
  bool is_partial = false;

  // A callback invoked when the partial task has completed.
  std::function<void()> done_callback;

  ThreadSafeStatus* partial_status;
};
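
// Example: constructing a (non-partial) task and handing it an input. This is
// an illustrative sketch; the tensor name and `input_tensor` are assumptions,
// and in practice the task is enqueued on a BatchScheduler rather than run
// directly.
//
//   std::vector<string> output_names = {"output"};  // assumed tensor name
//   std::vector<Tensor> outputs;
//   Notification done;
//   Status status;
//   std::unique_ptr<TfLiteBatchTask> task;
//   TfLiteBatchTask::CreateTfLiteBatchTask(&output_names, &outputs, &done,
//                                          &status, &task);
//   task->inputs.push_back(input_tensor);  // dim 0 is the batch dimension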

// Factory for the batch scheduler that drives TfLiteSession's batched
// execution.
using SchedulerCreator = std::function<Status(
    const BasicBatchScheduler<TfLiteBatchTask>::Options& options,
    std::function<void(std::unique_ptr<Batch<TfLiteBatchTask>>)>,
    std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>>*)>;
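
// Example: a SchedulerCreator that simply forwards to
// BasicBatchScheduler::Create (a minimal sketch; error handling elided):
//
//   SchedulerCreator creator =
//       [](const BasicBatchScheduler<TfLiteBatchTask>::Options& options,
//          std::function<void(std::unique_ptr<Batch<TfLiteBatchTask>>)>
//              process_batch,
//          std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>>* scheduler) {
//         return BasicBatchScheduler<TfLiteBatchTask>::Create(
//             options, std::move(process_batch), scheduler);
//       };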

// A session to run inference on a TensorFlow Lite model.
//
class TfLiteSession : public ServingSession {
 public:
  // Creates a TfLiteSession object from `buffer`, which holds a serialized
  // TFLite flatbuffer model. Also returns the SignatureDef map derived from
  // the model's inputs and outputs.
  //
  // Running a worker in the caller thread allows it to run on the parent
  // thread, which may be desired to increase concurrency at the cost of
  // additional thread-context overhead. Defaults to false.
  static Status Create(string&& buffer, const SessionOptions& options,
                       int num_pools, int num_interpreters_per_pool,
                       std::unique_ptr<TfLiteSession>* tflite_session,
                       ::google::protobuf::Map<string, SignatureDef>* signatures);
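
  // Example: building a session from serialized flatbuffer bytes (a minimal
  // sketch; `ReadModelFile` is a hypothetical helper and the pool sizes are
  // illustrative, not recommendations):
  //
  //   string model_bytes = ReadModelFile();  // hypothetical helper
  //   std::unique_ptr<TfLiteSession> session;
  //   ::google::protobuf::Map<string, SignatureDef> signatures;
  //   TF_RETURN_IF_ERROR(TfLiteSession::Create(
  //       std::move(model_bytes), SessionOptions(), /*num_pools=*/1,
  //       /*num_interpreters_per_pool=*/1, &session, &signatures));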

  // Creates a BasicBatchScheduler from `options`; usable as the
  // SchedulerCreator passed to SetScheduler().
  static Status CreateDefaultBasicBatchScheduler(
      const BasicBatchScheduler<TfLiteBatchTask>::Options& options,
      std::function<void(std::unique_ptr<Batch<TfLiteBatchTask>>)>
          process_batch_callback,
      std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>>* batch_scheduler);

  // Splits `input_task` into one or more partial tasks so that an oversized
  // request can be spread across batches.
  static Status SplitTfLiteInputTask(
      std::unique_ptr<TfLiteBatchTask>* input_task_ptr,
      int open_batch_remaining_slot, int max_batch_size,
      std::vector<std::unique_ptr<TfLiteBatchTask>>* output_tasks);
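
  // Example: wiring the split function into the scheduler options (a hedged
  // sketch, assuming the batching library's `enable_large_batch_splitting`
  // and `split_input_task_func` options):
  //
  //   BasicBatchScheduler<TfLiteBatchTask>::Options options;
  //   options.enable_large_batch_splitting = true;
  //   options.split_input_task_func = TfLiteSession::SplitTfLiteInputTask;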

  ~TfLiteSession() override = default;

  Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs) override;

  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata) override;

  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata,
             const thread::ThreadPoolOptions& thread_pool_options) override;
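
  // Example: running inference on named tensors (an illustrative sketch; the
  // tensor names "input" and "output" are assumptions about the model):
  //
  //   Tensor input(DT_FLOAT, TensorShape({1, 224, 224, 3}));
  //   std::vector<Tensor> outputs;
  //   TF_RETURN_IF_ERROR(session->Run(
  //       {{"input", input}}, {"output"}, /*target_node_names=*/{}, &outputs));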

  Status ListDevices(std::vector<DeviceAttributes>* response) override;

  // Installs a batch scheduler built by `scheduler_creator` with the given
  // `options`; subsequent Run() calls are batched through it.
  Status SetScheduler(
      const SchedulerCreator& scheduler_creator,
      const BasicBatchScheduler<TfLiteBatchTask>::Options& options);

  // Returns the options the batch scheduler was configured with.
  BasicBatchScheduler<TfLiteBatchTask>::Options GetSchedulerOptions() {
    return scheduler_options_;
  }
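
  // Example: enabling batched execution (a minimal sketch; the option values
  // are illustrative only, not tuning recommendations):
  //
  //   BasicBatchScheduler<TfLiteBatchTask>::Options options;
  //   options.max_batch_size = 32;
  //   options.batch_timeout_micros = 1000;
  //   options.num_batch_threads = 2;
  //   TF_RETURN_IF_ERROR(session->SetScheduler(
  //       &TfLiteSession::CreateDefaultBasicBatchScheduler, options));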

 private:
  TfLiteSession(std::map<string, int>&& input_tensor_to_index,
                std::map<string, int>&& output_tensor_to_index, string&& buffer,
                std::unique_ptr<tflite::FlatBufferModel> model,
                std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool);
  Status RunInternal(
      const std::vector<int>& tflite_input_indices,
      const std::vector<std::vector<const Tensor*>>& merged_inputs,
      const std::vector<string>& output_tensor_names,
      std::vector<Tensor>* combined_outputs, int batch_size,
      int* fixed_batch_size = nullptr);
  const std::map<string, int> input_tensor_to_index_;
  const std::map<string, int> output_tensor_to_index_;
  const string model_serialized_bytes_;
  const std::unique_ptr<tflite::FlatBufferModel> model_;
  const std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool_;
  bool use_fixed_batch_size_;
  std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>> scheduler_;
  BasicBatchScheduler<TfLiteBatchTask>::Options scheduler_options_;
  void ProcessBatch(std::unique_ptr<Batch<TfLiteBatchTask>> batch);
  TF_DISALLOW_COPY_AND_ASSIGN(TfLiteSession);
};

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFLITE_SESSION_H_