16 #include "tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/register.h"

namespace tensorflow {
namespace serving {
namespace internal {

TfLiteInterpreterWrapper::TfLiteInterpreterWrapper(
    std::unique_ptr<tflite::ExternalCpuBackendContext> external_context,
    std::unique_ptr<tflite::Interpreter> interpreter)
#ifdef TFLITE_PROFILE
    : external_context_(std::move(external_context)),
      interpreter_(std::move(interpreter)),
      max_num_entries_(TFLITE_PROFILE_EVENTS),
      profiler_(max_num_entries_)
#else
    : external_context_(std::move(external_context)),
      interpreter_(std::move(interpreter))
#endif
{
#ifdef TFLITE_PROFILE
  interpreter_->SetProfiler(&profiler_);
#endif
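  // Register a managed buffer for each string input so that SetStringData()
  // can reuse (and grow) a single allocation across requests.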
  for (const int& idx : interpreter_->inputs()) {
    const auto* tflite_tensor = interpreter_->tensor(idx);
    if (tflite_tensor->type == kTfLiteString) {
      tensor_buffer_.emplace(idx,
                             std::unique_ptr<char>(tflite_tensor->data.raw));
      tensor_buffer_max_bytes_[idx] = 0;
    }
  }
}

tensorflow::Status TfLiteInterpreterWrapper::SetStringData(
    const std::vector<const Tensor*>& tensors, TfLiteTensor* tflite_tensor,
    int tensor_index, int batch_size) {
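  // TFLite string tensors use a serialized buffer (see
  // tensorflow/lite/string_util.h):
  //   bytes [0, 4):             int32_t N, the number of strings;
  //   bytes [4, 4 * (N + 2)):   N + 1 int32_t offsets into the buffer, the
  //                             last one being the total buffer size;
  //   bytes [4 * (N + 2), ...): the concatenated string contents.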
  int32_t num_strings = batch_size;
  // offset_ is a member reused across invocations; reset it before
  // accumulating this request's cumulative string sizes.
  offset_.clear();
  size_t total_size = 0;
  offset_.push_back(static_cast<int32_t>(total_size));
  for (const auto& tensor : tensors) {
    const auto& flat = tensor->flat<tstring>();
    for (int i = 0; i < flat.size(); ++i) {
      total_size += flat(i).size();
      offset_.push_back(static_cast<int32_t>(total_size));
    }
  }
  size_t required_bytes = total_size + sizeof(int32_t) * (num_strings + 2);
  if (tensor_buffer_.find(tensor_index) == tensor_buffer_.end()) {
    return errors::Internal("Tensor input for index not found: ",
                            tensor_index);
  }
  if (required_bytes > tensor_buffer_max_bytes_[tensor_index]) {
    if (tflite_tensor->data.raw) {
      free(tflite_tensor->data.raw);
    }
    tflite_tensor->data.raw = reinterpret_cast<char*>(malloc(required_bytes));
    tensor_buffer_max_bytes_[tensor_index] = required_bytes;
  }
  tensor_buffer_[tensor_index].reset(tflite_tensor->data.raw);
  memcpy(tensor_buffer_[tensor_index].get(), &num_strings, sizeof(int32_t));
  int32_t start = sizeof(int32_t) * (num_strings + 2);
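  // offset_ holds cumulative sizes relative to the start of the character
  // data; shift by the header size to get absolute buffer offsets.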
  for (size_t i = 0; i < offset_.size(); i++) {
    size_t size_offset_i = start + offset_[i];
    if (size_offset_i > std::numeric_limits<int32_t>::max()) {
      return errors::Internal("Invalid size, string input too large:",
                              size_offset_i);
    }
    int32_t offset_i = static_cast<int32_t>(size_offset_i);
    memcpy(tensor_buffer_[tensor_index].get() + sizeof(int32_t) * (i + 1),
           &offset_i, sizeof(int32_t));
  }
  for (const auto& tensor : tensors) {
    const auto& flat = tensor->flat<tstring>();
    for (int i = 0; i < flat.size(); ++i) {
      memcpy(tensor_buffer_[tensor_index].get() + start, flat(i).data(),
             flat(i).size());
      start += flat(i).size();
    }
  }
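  // Hand the buffer over to the tensor: kTfLiteDynamic means the TFLite
  // runtime owns data.raw and may free it, so release() keeps
  // tensor_buffer_ from double freeing.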
  tflite_tensor->data.raw = tensor_buffer_[tensor_index].release();
  tflite_tensor->bytes = required_bytes;
  tflite_tensor->allocation_type = kTfLiteDynamic;
  return absl::OkStatus();
}

TfLiteStatus TfLiteInterpreterWrapper::Invoke() {
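  // When built with TFLITE_PROFILE, per-op stats are collected for every run
  // after the first (warm-up) invocation and aggregated by run_summarizer_;
  // the summary is reset every MAX_PROFILE_EVENTS invocations.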
#ifdef TFLITE_PROFILE
  if (invocation_count_ > 0) {
    profiler_.StartProfiling();
  }
#endif
  auto status = interpreter_->Invoke();
#ifdef TFLITE_PROFILE
  if (invocation_count_ > 0) {
    profiler_.StopProfiling();
    auto profile_events = profiler_.GetProfileEvents();
    run_summarizer_.ProcessProfiles(profile_events, *interpreter_);
  }
  if (invocation_count_++ >= MAX_PROFILE_EVENTS) {
    run_summarizer_.Clear();
    invocation_count_ = 0;
  }
#endif
  return status;
}

tensorflow::Status TfLiteInterpreterWrapper::CreateTfLiteInterpreterWrapper(
    const tflite::FlatBufferModel& model,
    const tensorflow::SessionOptions& options,
    std::unique_ptr<TfLiteInterpreterWrapper>& wrapper) {
  tflite::ops::builtin::BuiltinOpResolver resolver;
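  // Register the ParseExample custom op so models that parse serialized
  // tf.Example protos can run.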
  tflite::ops::custom::AddParseExampleOp(&resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
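  // Each interpreter runs single-threaded with an initial batch size of one;
  // parallelism is expected to come from pooling multiple interpreters.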
  const int batch_size = 1;
  const int num_threads = 1;
  if (tflite::InterpreterBuilder(model, resolver)(&interpreter, num_threads) !=
      kTfLiteOk) {
    return errors::Internal(
        "Failed to create a TFLite interpreter with the given model");
  }
  std::unique_ptr<tflite::ExternalCpuBackendContext> external_context(
      new tflite::ExternalCpuBackendContext());
  std::unique_ptr<tflite::CpuBackendContext> cpu_backend_context(
      new tflite::CpuBackendContext());
  cpu_backend_context->SetUseCaching(true);
  cpu_backend_context->SetMaxNumThreads(num_threads);
  external_context->set_internal_backend_context(
      std::move(cpu_backend_context));
  interpreter->SetExternalContext(kTfLiteCpuBackendContext,
                                  external_context.get());
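  // Size the string input to the default batch before allocating tensors;
  // its shape depends on the request batch rather than the model.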
  const int idx = interpreter->inputs()[0];
  const auto* tensor = interpreter->tensor(idx);
  if (tensor->type == kTfLiteString) {
    if (interpreter->ResizeInputTensor(idx, {batch_size}) != kTfLiteOk) {
      return errors::Internal("Failed to resize input");
    }
  }
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    return errors::Internal("Failed to allocate tensors");
  }
  wrapper.reset(new TfLiteInterpreterWrapper(std::move(external_context),
                                             std::move(interpreter)));
  return absl::OkStatus();
}

tensorflow::Status TfLiteInterpreterPool::CreateTfLiteInterpreterPool(
    const tflite::FlatBufferModel* model,
    const tensorflow::SessionOptions& options, int pool_size,
    std::unique_ptr<TfLiteInterpreterPool>& interpreter_pool) {
  std::vector<std::unique_ptr<TfLiteInterpreterWrapper>> interpreters(
      pool_size);
  for (int i = 0; i < pool_size; i++) {
    auto& wrapper = interpreters[i];
    TF_RETURN_IF_ERROR(TfLiteInterpreterWrapper::CreateTfLiteInterpreterWrapper(
        *model, options, wrapper));
  }
  interpreter_pool.reset(new TfLiteInterpreterPool(std::move(interpreters)));
  return absl::OkStatus();
}

}  // namespace internal
}  // namespace serving
}  // namespace tensorflow