TensorFlow Serving C++ API Documentation
tflite_interpreter_pool.h
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFLITE_INTERPRETER_POOL_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFLITE_INTERPRETER_POOL_H_
18 
19 #include <map>
20 #include <vector>
21 
22 #include "absl/base/thread_annotations.h"
23 #include "absl/synchronization/mutex.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/lib/gtl/array_slice.h"
26 #include "tensorflow/core/platform/cpu_info.h"
27 #include "tensorflow/core/platform/tstring.h"
28 #include "tensorflow/core/public/session_options.h"
29 #include "tensorflow/lite/c/common.h"
30 #include "tensorflow/lite/model.h"
31 #ifdef TFLITE_PROFILE
32 #ifndef TFLITE_PROFILE_EVENTS
33 #define TFLITE_PROFILE_EVENTS 2000
34 #endif
35 #include "tensorflow/lite/profiling/buffered_profiler.h"
36 #include "tensorflow/lite/profiling/profile_summarizer.h"
37 #include "tensorflow/lite/profiling/profile_summary_formatter.h"
38 #endif
39 #include "tensorflow/lite/string_util.h"
40 
41 namespace tensorflow {
42 namespace serving {
43 namespace internal {
44 
// Initial batch size that interpreter tensors are allocated for; actual batch
// size may be raised later via TfLiteInterpreterWrapper::SetBatchSize.
constexpr int kInitialBatchSize = 500;
46 
48  // Wrapper class for a single TfLite Interpreter for use in an interpreter
49  // pool.
50  public:
51  // Create an interpreter and external context and wrap it in the class
52  // for use with an InterpreterPool.
53  static Status CreateTfLiteInterpreterWrapper(
54  const tflite::FlatBufferModel& model,
55  const tensorflow::SessionOptions& options,
56  std::unique_ptr<TfLiteInterpreterWrapper>& wrapper);
57 
58  // Constructor for wrapper takes only an initialized interpreter.
60  std::unique_ptr<tflite::ExternalCpuBackendContext> external_context,
61  std::unique_ptr<tflite::Interpreter> interpreter);
62 
63  TfLiteInterpreterWrapper(std::unique_ptr<tflite::Interpreter> interpreter)
64  : TfLiteInterpreterWrapper(nullptr, std::move(interpreter)) {}
65 
66  // Returns the underlying interpreter.
67  tflite::Interpreter* Get() { return interpreter_.get(); }
68 
69  // Get the allocated batch size of the interpreter.
70  int GetBatchSize() { return batch_size_; }
71 
72  // Set the batch size.
73  void SetBatchSize(int batch_size) { batch_size_ = batch_size; }
74 
75  // Invokes the interpreter.
76  TfLiteStatus Invoke();
77 #ifdef TFLITE_PROFILE
78  void WriteOutput(const std::string& header, const string& data,
79  std::ostream* stream) {
80  (*stream) << header << std::endl;
81  (*stream) << data << std::endl;
82  }
83 
84  void WriteProfileData() {
85  if (run_summarizer_.HasProfiles()) {
86  WriteOutput("Operator-wise Profiling Info for Regular Benchmark Runs:",
87  run_summarizer_.GetOutputString(), &std::cout);
88  }
89  }
90 #endif
91 
92  // Sets the contents of the internal buffer _tensor_buffer_ to the tflite
93  // formatted string buffer equivalent stored in `batch` and sets
94  // raw pointer of `tflite_tensor` to the internal buffer. If the required
95  // size is larger than the current size, will allocate new memory and
96  // free the existing buffer.
97  tensorflow::Status SetStringData(const std::vector<const Tensor*>& tensors,
98  TfLiteTensor* tflite_tensor,
99  int tensor_index, int batch_size);
100 
101  private:
102  // External cpu context to enable caching.
103  std::unique_ptr<tflite::ExternalCpuBackendContext> external_context_;
104  std::unique_ptr<tflite::Interpreter> interpreter_;
105  int batch_size_ = 1;
106  std::map<int, std::unique_ptr<char>> tensor_buffer_;
107  std::map<int, size_t> tensor_buffer_max_bytes_;
108  std::vector<int32_t> offset_;
109 #ifdef TFLITE_PROFILE
110  int max_num_entries_;
111  tflite::profiling::ProfileSummarizer run_summarizer_;
112  tflite::profiling::BufferedProfiler profiler_;
113  int invocation_count_ = 0;
114 #endif
115 };
116 
117 // Contains a vector of TfLiteInterpreterWrapper, which are protected by mutex.
118 // When GetInterpreter is called, will either release a unique ptr to the
119 // caller or block if the vector is empty.
121  public:
122  // Creates a TfLiteSessionPool with model, session options,
123  // pool_size number of interpreters.
124  static tensorflow::Status CreateTfLiteInterpreterPool(
125  const tflite::FlatBufferModel* model,
126  const tensorflow::SessionOptions& options, int pool_size,
127  std::unique_ptr<TfLiteInterpreterPool>& interpreter_pool);
128 
129  // Returns a TFLite interpreter wrapper object. Caller may *block* waiting for
130  // a free interpreter pool to be available.
131  std::unique_ptr<TfLiteInterpreterWrapper> GetInterpreter() {
132  auto interpreter_available = [this]() ABSL_SHARED_LOCKS_REQUIRED(mutex_) {
133  return !this->available_.empty();
134  };
135  mutex_.LockWhen(absl::Condition(&interpreter_available));
136  auto pool = std::move(available_.back());
137  available_.pop_back();
138  mutex_.Unlock();
139  return pool;
140  }
141 
142  // Returns an interpreter wrapper to the available pool.
143  void ReturnInterpreter(
144  std::unique_ptr<TfLiteInterpreterWrapper> interpreter) {
145  absl::MutexLock l(&mutex_);
146  available_.emplace_back(std::move(interpreter));
147  }
148 
149  private:
151  std::vector<std::unique_ptr<TfLiteInterpreterWrapper>> interpreters)
152  : available_(std::move(interpreters)) {}
153  mutable absl::Mutex mutex_;
154  std::vector<std::unique_ptr<TfLiteInterpreterWrapper>> available_
155  ABSL_GUARDED_BY(mutex_);
156 };
157 
158 } // namespace internal
159 } // namespace serving
160 } // namespace tensorflow
161 
162 #endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFLITE_INTERPRETER_POOL_H_