TensorFlow Serving C++ API Documentation
tflite_session.cc
/* Copyright 2019 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_serving/servables/tensorflow/tflite_session.h"

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/functional/bind_front.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/signature/signature_def_util.h"
#include "tensorflow/lite/util.h"
#include "tensorflow_serving/batching/incremental_barrier.h"
#include "tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h"

namespace tensorflow {
namespace serving {

// TensorInfoMap maps a TFLite tensor name to <TF TensorInfo, TFLite tensor index>.
namespace {

Status TfLiteTypeToTfType(TfLiteType tflite_type, DataType* type) {
  switch (tflite_type) {
    case kTfLiteNoType:
      *type = tensorflow::DT_INVALID;
      break;
    case kTfLiteFloat32:
      *type = tensorflow::DT_FLOAT;
      break;
    case kTfLiteInt32:
      *type = tensorflow::DT_INT32;
      break;
    case kTfLiteUInt8:
      *type = tensorflow::DT_UINT8;
      break;
    case kTfLiteInt64:
      *type = tensorflow::DT_INT64;
      break;
    case kTfLiteString:
      *type = tensorflow::DT_STRING;
      break;
    case kTfLiteBool:
      *type = tensorflow::DT_BOOL;
      break;
    case kTfLiteInt16:
      *type = tensorflow::DT_INT16;
      break;
    case kTfLiteComplex64:
      *type = tensorflow::DT_COMPLEX64;
      break;
    case kTfLiteInt8:
      *type = tensorflow::DT_INT8;
      break;
    default:
      return errors::Internal("Unknown TfLite type: ", tflite_type);
  }
  return absl::OkStatus();
}

std::string TfToTfLiteLegacyTensorName(const string& tf_name) {
  // TF variable names have a ':0' suffix (e.g. "input:0"); early versions of
  // the TF Lite converter used to strip this suffix.
  std::pair<absl::string_view, absl::string_view> name_index =
      absl::StrSplit(tf_name, absl::MaxSplits(':', 1));
  return std::string(name_index.first);
}

// Checks that an input/output tensor actually exists. If not, attempts to
// update the tensor name with legacy TFLite tensor naming.
Status FixTfLiteTensorName(const std::map<string, int>& tensor_name_map,
                           string& tensor_name) {
  if (tensor_name_map.find(tensor_name) != tensor_name_map.end()) {
    return absl::OkStatus();
  }

  // Try to update with the legacy tflite tensor name.
  const string& legacy_tflite_name = TfToTfLiteLegacyTensorName(tensor_name);
  if (tensor_name_map.find(legacy_tflite_name) != tensor_name_map.end()) {
    tensor_name = legacy_tflite_name;
    return absl::OkStatus();
  }

  return errors::Internal("Unknown tensor '", tensor_name, "'.");
}
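
// Usage note (illustrative example added by the editor; the names below are
// hypothetical): given tensor_name_map = {{"serving_default_x", 0}}, calling
// FixTfLiteTensorName(tensor_name_map, name) with name == "serving_default_x:0"
// returns OK and rewrites `name` to "serving_default_x"; a name that matches
// neither form yields an Internal error.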

Status TfLiteTensorToTensorInfo(const TfLiteTensor* tflite_tensor,
                                TensorInfo* info) {
  DataType tf_type;
  TF_RETURN_IF_ERROR(TfLiteTypeToTfType(tflite_tensor->type, &tf_type));
  info->set_dtype(tf_type);
  info->set_name(tflite_tensor->name);
  for (int i = 0; i < tflite_tensor->dims->size; i++) {
    info->mutable_tensor_shape()->add_dim()->set_size(
        tflite_tensor->dims->data[i]);
  }
  return absl::OkStatus();
}

Status GetTensorInfoMap(const tflite::Interpreter* interpreter, bool input,
                        TensorInfoMap* infomap) {
  const std::vector<int>& indices =
      input ? interpreter->inputs() : interpreter->outputs();
  const string& input_str = input ? "Input" : "Output";
  for (int index : indices) {
    const TfLiteTensor* tensor = interpreter->tensor(index);
    if (tensor->name == nullptr) {
      return errors::Internal(input_str,
                              " name missing for tensor index: ", index);
    }
    TensorInfo info;
    TF_RETURN_IF_ERROR(TfLiteTensorToTensorInfo(tensor, &info));
    if (!infomap->emplace(tensor->name, std::pair<TensorInfo, int>(info, index))
             .second) {
      return errors::AlreadyExists(input_str, " tensor name: ", tensor->name,
                                   " has multiple indices");
    }
  }
  return absl::OkStatus();
}

std::vector<int> TensorDims(const Tensor& tensor) {
  std::vector<int> dims(tensor.dims());
  for (int i = 0; i < tensor.dims(); ++i) {
    dims[i] = static_cast<int>(tensor.dim_size(i));
  }
  return dims;
}

// Create output tensors, making sure they are the right size.
Status CreateOutputTensors(
    std::unique_ptr<internal::TfLiteInterpreterWrapper>& interpreter_wrapper,
    const std::vector<string>& output_tensor_names,
    const std::map<string, int>& output_tensor_to_idx,
    std::map<int32_t, Tensor*>& tflite_idx_to_output_tensor,
    std::vector<Tensor>* output_tensors) {
  output_tensors->reserve(output_tensor_names.size());
  for (std::string tfname : output_tensor_names) {
    auto fix_status = FixTfLiteTensorName(output_tensor_to_idx, tfname);
    if (fix_status != absl::OkStatus()) {
      return errors::Internal("Missing output TFLite tensor: ", tfname, ": ",
                              fix_status.message());
    }
    const int tflite_idx = output_tensor_to_idx.at(tfname);
    TensorShape tf_shape;
    const auto& interpreter = interpreter_wrapper->Get();
    const auto* tflite_tensor = interpreter->tensor(tflite_idx);
    for (int i = 0; i < tflite_tensor->dims->size; ++i) {
      tf_shape.AddDim(tflite_tensor->dims->data[i]);
    }
    DataType tf_type;
    TF_RETURN_IF_ERROR(TfLiteTypeToTfType(tflite_tensor->type, &tf_type));
    output_tensors->emplace_back(tf_type, tf_shape);
    tflite_idx_to_output_tensor[tflite_idx] = &output_tensors->back();
  }
  return absl::OkStatus();
}

Status SetInputAndInvokeMiniBatch(
    std::unique_ptr<internal::TfLiteInterpreterWrapper>& interpreter_wrapper,
    const std::vector<int>& tflite_input_indices,
    const std::vector<std::vector<const Tensor*>>& inputs, int batch_size,
    int* fixed_batch_size) {
  auto* interpreter = interpreter_wrapper->Get();
  // Load input data from TensorFlow tensors.
  for (int i = 0; i < tflite_input_indices.size(); ++i) {
    int tflite_input_idx = tflite_input_indices[i];
    auto tflite_input_tensor = interpreter->tensor(tflite_input_idx);
    const auto& tf_input_tensors = inputs[i];
    if (tflite_input_tensor->type != kTfLiteString) {
      const Tensor* tf_input_tensor = tf_input_tensors[0];
      // concated.tensor_data() may be accessed later.
      Tensor concated;
      if (tf_input_tensors.size() > 1) {
        std::vector<Tensor> to_concatenate;
        to_concatenate.reserve(tf_input_tensors.size());
        for (const auto* t : tf_input_tensors) {
          // *t is const, so std::move would silently copy; copy explicitly.
          to_concatenate.push_back(*t);
        }
        TF_RETURN_IF_ERROR(tensor::Concat(to_concatenate, &concated));
        tf_input_tensor = &concated;
      }
      auto tensor_bytes = tf_input_tensor->tensor_data();
      std::vector<int> tf_dims = TensorDims(*tf_input_tensor);
      std::vector<int> tflite_dims(
          tflite_input_tensor->dims->data,
          tflite_input_tensor->dims->data + tflite_input_tensor->dims->size);
      if (tensor_bytes.size() != tflite_input_tensor->bytes ||
          tf_dims != tflite_dims) {
        if (interpreter->ResizeInputTensor(tflite_input_idx, tf_dims) !=
            kTfLiteOk) {
          return errors::Internal(
              "Failed to resize input tensor: ", tflite_input_tensor->name,
              " from ", tflite_input_tensor->bytes, " to ", tensor_bytes.size(),
              " bytes.");
        }
        if (interpreter->AllocateTensors() != kTfLiteOk) {
          return errors::Internal("Failed to allocate tensors");
        }
      }
      std::memcpy(tflite_input_tensor->data.raw, tensor_bytes.data(),
                  tensor_bytes.size());
    } else {
      // Copy the string tensor data to the input tflite tensor.
      const bool needs_resize =
          fixed_batch_size ? batch_size > interpreter_wrapper->GetBatchSize()
                           : batch_size != interpreter_wrapper->GetBatchSize();
      if (needs_resize) {
        interpreter->ResizeInputTensor(tflite_input_idx, {batch_size});
        interpreter_wrapper->SetBatchSize(batch_size);
        if (interpreter->AllocateTensors() != kTfLiteOk) {
          return errors::Internal("Failed to allocate tensors");
        }
      }
      if (fixed_batch_size) {
        *fixed_batch_size = interpreter_wrapper->GetBatchSize();
      }
      TF_RETURN_IF_ERROR(interpreter_wrapper->SetStringData(
          tf_input_tensors, tflite_input_tensor, tflite_input_idx, batch_size));
    }
  }
  if (interpreter_wrapper->Invoke() != kTfLiteOk) {
    return errors::Internal("Failed to invoke TfLite interpreter");
  }
  return absl::OkStatus();
}
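
// Sketch of the string-input resize logic above (illustrative numbers, not
// from the original source): if the interpreter's current batch size is 4 and
// an incoming mini-batch has batch_size 6, then in fixed-batch mode
// (fixed_batch_size != nullptr) the input is resized only because 6 > 4, and
// *fixed_batch_size is set to the resulting interpreter batch size; without
// fixed-batch mode any mismatch (6 != 4, or 2 != 4) triggers a resize and
// reallocation.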

Status SetMiniBatchOutput(
    std::unique_ptr<internal::TfLiteInterpreterWrapper>& interpreter_wrapper,
    const std::map<int, Tensor*>& tflite_idx_to_output_tensor,
    std::vector<Tensor>* outputs) {
  for (const auto& entry : tflite_idx_to_output_tensor) {
    Tensor* tensor = entry.second;
    const DataType tf_type = tensor->dtype();
    if (tensor->NumElements() == 0) {
      continue;
    }
    const auto* interpreter = interpreter_wrapper->Get();
    auto tflite_tensor = interpreter->tensor(entry.first);
    if (DataTypeCanUseMemcpy(tf_type)) {
      auto tensor_bytes = tensor->tensor_data();
      int offset = 0;
      size_t tflite_tensor_bytes = tflite_tensor->bytes;
      std::memcpy(const_cast<char*>(tensor_bytes.data() + offset),
                  tflite_tensor->data.raw, tflite_tensor_bytes);
    } else if (tflite_tensor->type == kTfLiteString) {
      const int string_count = tflite::GetStringCount(tflite_tensor);
      int num_strings = string_count;
      int offset = 0;
      auto str_tensors = tensor->flat<tstring>();
      for (int i = 0; i < num_strings; i++) {
        const auto& ref = tflite::GetString(tflite_tensor, i);
        str_tensors(i + offset).assign(ref.str, ref.len);
      }
    }
  }
  return absl::OkStatus();
}

int GetModelBatchSize(const tflite::Model* model) {
  const auto* primary_subgraph = model->subgraphs()->Get(0);
  const auto* inputs = primary_subgraph->inputs();
  if (inputs->size() == 1) {
    // Only models with a single input tensor can be batched, since
    // SplitTfLiteInputTask only works on single-input-tensor jobs.
    const int tensor_id = inputs->Get(0);
    const auto* tensor = primary_subgraph->tensors()->Get(tensor_id);
    return tensor->shape()->Get(0);
  }
  return -1;
}
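
// Example (illustrative): a model whose single input has shape [8, 224, 224, 3]
// yields a batch size of 8 here; a model with more than one input returns -1
// and is treated as not having a fixed model batch size.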

}  // namespace

// Split an input task up into multiple tasks.
Status TfLiteSession::SplitTfLiteInputTask(
    std::unique_ptr<TfLiteBatchTask>* input_task_ptr,
    int open_batch_remaining_slot, int max_batch_size,
    std::vector<std::unique_ptr<TfLiteBatchTask>>* output_tasks) {
  auto* input_task = input_task_ptr->get();
  auto split_output =
      std::make_shared<std::vector<std::unique_ptr<std::vector<Tensor>>>>();
  auto partial_status = std::make_shared<ThreadSafeStatus>();
  auto split_task_done_callback = [split_output, partial_status, input_task]() {
    // Notify the input task once this callback finishes (the cleanup runs on
    // every return path).
    auto cleanup = gtl::MakeCleanup([done_notification = input_task->done]() {
      done_notification->Notify();
    });

    // The partial status is set while the split tasks actually run.
    if (!partial_status->status().ok()) {
      *input_task->status = partial_status->status();
      return;
    }

    // Get the total number of tensors to concatenate (the number of tasks).
    int output_size = split_output->size();
    // Each split contains the same number of output tensors.
    int tensor_size = (*split_output)[0]->size();

    // For each output tensor...
    for (int tensor_idx = 0; tensor_idx < tensor_size; ++tensor_idx) {
      Tensor output_tensor;  // The concatenated tensor.
      std::vector<Tensor> to_concatenate;
      to_concatenate.reserve(output_size);
      // ...concatenate the outputs of every split task.
      for (int output_idx = 0; output_idx < output_size; ++output_idx) {
        to_concatenate.push_back(
            std::move((*(*split_output)[output_idx])[tensor_idx]));
      }
      const auto concat_status = tensor::Concat(to_concatenate, &output_tensor);
      if (!concat_status.ok()) {
        *input_task->status = concat_status;
        return;
      }
      // Add the concatenated tensor to the input task's outputs.
      input_task->outputs->push_back(output_tensor);
    }
    *input_task->status = absl::OkStatus();
  };

  // The callback will run only after all partial tasks have finished.
  IncrementalBarrier barrier(std::move(split_task_done_callback));
  std::vector<int64_t> output_task_sizes;

  if (open_batch_remaining_slot > 0) {
    output_task_sizes.push_back(open_batch_remaining_slot);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }

  for (int left_task_size = input_task->size() - open_batch_remaining_slot;
       left_task_size > 0; left_task_size -= max_batch_size) {
    int next_task_size = std::min(left_task_size, max_batch_size);
    output_task_sizes.push_back(next_task_size);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }

  const int output_task_num = output_task_sizes.size();
  output_tasks->reserve(output_task_num);
  for (int i = 0; i < output_task_num; ++i) {
    std::unique_ptr<TfLiteBatchTask> task;
    TfLiteBatchTask::CreatePartialTfLiteBatchTask(
        input_task->input_indices, input_task->output_tensor_names,
        (*split_output)[i].get(), barrier.Inc(), partial_status.get(), &task);
    output_tasks->push_back(std::move(task));
  }

  for (int i = 0; i < input_task->inputs.size(); ++i) {
    const Tensor& input = input_task->inputs[i];
    std::vector<Tensor> split_tensors;
    auto status = tensor::Split(input, output_task_sizes, &split_tensors);
    if (status != absl::OkStatus()) {
      return status;
    }
    for (int output_idx = 0; output_idx < output_task_num; ++output_idx) {
      auto& output_task = (*output_tasks)[output_idx];
      output_task->inputs.push_back(std::move(split_tensors[output_idx]));
    }
  }
  return absl::OkStatus();
}
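
// Worked example (illustrative, not part of the original source): splitting an
// input task of size 10 with open_batch_remaining_slot = 2 and
// max_batch_size = 4 produces partial tasks of sizes {2, 4, 4}. Each input
// tensor is sliced along dimension 0 with tensor::Split using those sizes, and
// the done callback above concatenates the partial outputs back in the same
// order before notifying the original task.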

Status TfLiteSession::CreateDefaultBasicBatchScheduler(
    const BasicBatchScheduler<TfLiteBatchTask>::Options& options,
    std::function<void(std::unique_ptr<Batch<TfLiteBatchTask>>)>
        process_batch_callback,
    std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>>* batch_scheduler) {
  std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>> basic_batch_scheduler;
  TF_RETURN_IF_ERROR(BasicBatchScheduler<TfLiteBatchTask>::Create(
      options, process_batch_callback, &basic_batch_scheduler));
  *batch_scheduler = std::move(basic_batch_scheduler);
  return absl::OkStatus();
}

Status TfLiteSession::SetScheduler(
    const SchedulerCreator& scheduler_creator,
    const BasicBatchScheduler<TfLiteBatchTask>::Options& options) {
  use_fixed_batch_size_ = true;
  scheduler_options_ = options;
  auto bound_scheduler_creator = absl::bind_front(
      &TfLiteSession::CreateDefaultBasicBatchScheduler, scheduler_options_);
  return bound_scheduler_creator(
      [this](std::unique_ptr<Batch<TfLiteBatchTask>> batch) {
        this->ProcessBatch(std::move(batch));
      },
      &scheduler_);
}

Status TfLiteSession::Create(string&& buffer, const SessionOptions& options,
                             int num_pools, int num_interpreters_per_pool,
                             std::unique_ptr<TfLiteSession>* tflite_session,
                             ::google::protobuf::Map<string, SignatureDef>* signatures) {
  auto model = tflite::FlatBufferModel::BuildFromModel(
      flatbuffers::GetRoot<tflite::Model>(buffer.data()));
  if (model == nullptr) {
    return errors::InvalidArgument("Cannot build FlatBufferModel from buffer.");
  }

  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::ops::custom::AddParseExampleOp(&resolver);

  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return errors::Internal("Cannot build Interpreter from buffer.");
  }

  TensorInfoMap inputs;
  TF_RETURN_IF_ERROR(GetTensorInfoMap(interpreter.get(), true, &inputs));
  TensorInfoMap outputs;
  TF_RETURN_IF_ERROR(GetTensorInfoMap(interpreter.get(), false, &outputs));

  // Map of TFLite tensor name -> tensor index
  std::map<string, int> input_tensor_to_index;
  std::map<string, int> output_tensor_to_index;
  for (const auto& info : inputs) {
    const string& tflite_tensor_name = info.first;
    input_tensor_to_index[tflite_tensor_name] = info.second.second;
  }
  for (const auto& info : outputs) {
    const string& tflite_tensor_name = info.first;
    output_tensor_to_index[tflite_tensor_name] = info.second.second;
  }

  // Attempt to read signature defs from the model file
  std::map<string, SignatureDef> signature_defs;
  const auto status =
      tflite::GetSignatureDefMap(model->GetModel(), &signature_defs);
  if (status != absl::OkStatus()) {
    return errors::InvalidArgument(
        "Invalid SignatureDefs found in TfLite model: ", status.message());
  }
  const bool has_lite_signature_def = !signature_defs.empty();

  signatures->clear();
  if (has_lite_signature_def) {
    // Check that input/output tensors in the signature defs refer to existing
    // tensors.
    // If not found, try to match with legacy TFLite name (without suffix).
    for (const auto& signature_item : signature_defs) {
      SignatureDef* tflite_signature = &(*signatures)[signature_item.first];
      tflite_signature->CopyFrom(signature_item.second);
      for (auto& input : *tflite_signature->mutable_inputs()) {
        TensorInfo* tensor_info = &input.second;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            FixTfLiteTensorName(input_tensor_to_index,
                                *tensor_info->mutable_name()),
            "Signature input ", input.first, " references an unknown tensor");
      }
      for (auto& output : *tflite_signature->mutable_outputs()) {
        TensorInfo* tensor_info = &output.second;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            FixTfLiteTensorName(output_tensor_to_index,
                                *tensor_info->mutable_name()),
            "Signature output ", output.first, " references an unknown tensor");
      }
    }
  } else {
    // Build a mock signature from the input/output tensors of the model.
    // TODO(b/169239308)
    LOG(WARNING) << "No signature def found in TFLite model. Generating one.";
    SignatureDef* sigdef = &(*signatures)[kDefaultServingSignatureDefKey];
    for (const auto& info : inputs) {
      string tflite_tensor_name = TfToTfLiteLegacyTensorName(info.first);
      (*sigdef->mutable_inputs())[tflite_tensor_name] = info.second.first;
    }
    for (const auto& info : outputs) {
      string tflite_tensor_name = TfToTfLiteLegacyTensorName(info.first);
      (*sigdef->mutable_outputs())[tflite_tensor_name] = info.second.first;
    }
    sigdef->set_method_name(kPredictMethodName);
  }

  const int num_interpreters = std::max(1, num_pools);
  const int model_batch_size = GetModelBatchSize(model->GetModel());

  std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool;
  TF_RETURN_IF_ERROR(
      internal::TfLiteInterpreterPool::CreateTfLiteInterpreterPool(
          model.get(), options, num_interpreters, interpreter_pool));

  tflite_session->reset(new TfLiteSession(
      std::move(input_tensor_to_index), std::move(output_tensor_to_index),
      std::move(buffer), std::move(model), std::move(interpreter_pool)));

  if (num_interpreters_per_pool > 1) {
    const int default_allowed_batch =
        (internal::kInitialBatchSize + num_interpreters_per_pool - 1) /
        num_interpreters_per_pool;
    const int min_allowed_batch =
        model_batch_size > 1 ? model_batch_size : default_allowed_batch;
    const int max_enqueued_batches = num_interpreters * 100;
    BasicBatchScheduler<TfLiteBatchTask>::Options scheduler_options;
    scheduler_options.num_batch_threads = num_interpreters;
    scheduler_options.max_batch_size = internal::kInitialBatchSize;
    scheduler_options.enable_large_batch_splitting = true;
    scheduler_options.max_execution_batch_size = min_allowed_batch;
    scheduler_options.max_enqueued_batches = max_enqueued_batches;
    scheduler_options.split_input_task_func = SplitTfLiteInputTask;
    TF_RETURN_IF_ERROR(
        (*tflite_session)
            ->SetScheduler(&TfLiteSession::CreateDefaultBasicBatchScheduler,
                           scheduler_options));
  }
  return absl::OkStatus();
}
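
// Typical usage sketch (illustrative only; ReadFileToString, the model path,
// and the tensor names are assumptions, not part of this file):
//
//   string model_bytes;
//   TF_RETURN_IF_ERROR(ReadFileToString(Env::Default(),
//                                       "/path/to/model.tflite", &model_bytes));
//   ::google::protobuf::Map<string, SignatureDef> signatures;
//   std::unique_ptr<TfLiteSession> session;
//   TF_RETURN_IF_ERROR(TfLiteSession::Create(
//       std::move(model_bytes), SessionOptions(), /*num_pools=*/1,
//       /*num_interpreters_per_pool=*/1, &session, &signatures));
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(session->Run({{"input_x", input_tensor}}, {"output_y"},
//                                   /*target_node_names=*/{}, &outputs));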

TfLiteSession::TfLiteSession(
    std::map<string, int>&& input_tensor_to_index,
    std::map<string, int>&& output_tensor_to_index, string&& buffer,
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool)
    : input_tensor_to_index_(std::move(input_tensor_to_index)),
      output_tensor_to_index_(std::move(output_tensor_to_index)),
      model_serialized_bytes_(std::move(buffer)),
      model_(std::move(model)),
      interpreter_pool_(std::move(interpreter_pool)) {}

Status TfLiteSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                          const std::vector<string>& output_tensor_names,
                          const std::vector<string>& target_node_names,
                          std::vector<Tensor>* outputs) {
  RunMetadata run_metadata;
  return Run(RunOptions(), inputs, output_tensor_names, target_node_names,
             outputs, &run_metadata);
}

Status TfLiteSession::Run(const RunOptions& run_options,
                          const std::vector<std::pair<string, Tensor>>& inputs,
                          const std::vector<string>& output_tensor_names,
                          const std::vector<string>& target_node_names,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata) {
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, run_metadata, thread::ThreadPoolOptions());
}

Status TfLiteSession::RunInternal(
    const std::vector<int>& tflite_input_indices,
    const std::vector<std::vector<const Tensor*>>& merged_inputs,
    const std::vector<string>& output_tensor_names,
    std::vector<Tensor>* combined_outputs, int batch_size,
    int* fixed_batch_size) {
#define RETURN_POOL_IF_ERROR(...)                                   \
  do {                                                              \
    ::tensorflow::Status _status = (__VA_ARGS__);                   \
    if (TF_PREDICT_FALSE(!_status.ok())) {                          \
      interpreter_pool_->ReturnInterpreter(std::move(interpreter)); \
      return _status;                                               \
    }                                                               \
  } while (0);
  auto interpreter = interpreter_pool_->GetInterpreter();
  RETURN_POOL_IF_ERROR(
      SetInputAndInvokeMiniBatch(interpreter, tflite_input_indices,
                                 merged_inputs, batch_size, fixed_batch_size));

  // Create return tensors and map the tflite tensor index to the
  // index of the created tensor.
  std::map<int32_t, Tensor*> tflite_idx_to_output_tensor;
  RETURN_POOL_IF_ERROR(CreateOutputTensors(
      interpreter, output_tensor_names, output_tensor_to_index_,
      tflite_idx_to_output_tensor, combined_outputs));

  // Set the contents of the return tensors.
  RETURN_POOL_IF_ERROR(SetMiniBatchOutput(
      interpreter, tflite_idx_to_output_tensor, combined_outputs));

#undef RETURN_POOL_IF_ERROR
  interpreter_pool_->ReturnInterpreter(std::move(interpreter));
  return absl::OkStatus();
}

Status TfLiteSession::Run(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata,
    const thread::ThreadPoolOptions& thread_pool_options) {
  std::map<int, const Tensor*> tflite_idx_to_input_tensor;
  for (const auto& input : inputs) {
    string name = input.first;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        FixTfLiteTensorName(input_tensor_to_index_, name),
        "Missing input TFLite tensor: ", name);
    const int index = input_tensor_to_index_.at(name);
    tflite_idx_to_input_tensor[index] = &input.second;
  }
  outputs->reserve(output_tensor_names.size());
  if (!scheduler_) {
    // No batching scheduler: run the request directly on one interpreter.
    std::vector<int> input_indices;
    std::vector<std::vector<const Tensor*>> merged_inputs;
    for (const auto& entry : tflite_idx_to_input_tensor) {
      const auto& tf_tensor = *entry.second;
      merged_inputs.push_back({&tf_tensor});
      input_indices.push_back(entry.first);
    }
    const int batch_size = merged_inputs.empty() || merged_inputs[0].empty()
                               ? 1
                               : merged_inputs[0][0]->dim_size(0);
    return RunInternal(input_indices, merged_inputs, output_tensor_names,
                       outputs, batch_size);
  }
  Notification done;
  Status status;
  std::unique_ptr<TfLiteBatchTask> task;
  TfLiteBatchTask::CreateTfLiteBatchTask(&output_tensor_names, outputs, &done,
                                         &status, &task);
  for (const auto& entry : tflite_idx_to_input_tensor) {
    task->input_indices.push_back(entry.first);
    // entry.second is a const Tensor*, so this pushes a copy (std::move on a
    // const Tensor would degenerate to a copy anyway).
    task->inputs.push_back(*entry.second);
  }
  TF_RETURN_IF_ERROR(scheduler_->Schedule(&task));
  done.WaitForNotification();
  return status;
}

Status TfLiteSession::ListDevices(std::vector<DeviceAttributes>* response) {
  return errors::Unimplemented("ListDevices is not yet supported.");
}

Status MergeInputTensors(const Batch<TfLiteBatchTask>& batch,
                         std::vector<std::vector<const Tensor*>>* merged_inputs,
                         int* batch_size) {
  if (batch.num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch.num_tasks());
  }
  const int tensors_per_task = batch.task(0).inputs.size();
  *batch_size = 0;
  // Each entry in merged_inputs holds, for one input, the tensors from every
  // task in the batch.
  for (int i = 0; i < tensors_per_task; ++i) {
    merged_inputs->emplace_back();
    std::vector<const Tensor*>& tensors_to_merge = merged_inputs->back();
    for (int j = 0; j < batch.num_tasks(); ++j) {
      const std::vector<Tensor>& inputs = batch.task(j).inputs;
      tensors_to_merge.push_back(&(inputs[i]));
      if (i == 0) {
        if (inputs[i].dims()) {
          *batch_size += inputs[i].dim_size(0);
        }
      }
    }
  }
  return absl::OkStatus();
}
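
// Layout example (illustrative): for a batch of 3 tasks that each carry 2
// input tensors, merged_inputs becomes
//   { {&task0.in0, &task1.in0, &task2.in0}, {&task0.in1, &task1.in1, &task2.in1} }
// and *batch_size is the sum of dimension 0 of the first input across tasks.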

Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
                          Batch<TfLiteBatchTask>* batch, int batch_size) {
  std::vector<int64_t> task_sizes(batch->num_tasks());
  int total_size = 0;
  for (int i = 0; i < batch->num_tasks(); ++i) {
    const int task_size = batch->task(i).size();
    task_sizes[i] = task_size;
    total_size += task_size;
  }

  if (total_size < batch_size) {
    task_sizes.push_back(batch_size - total_size);
  }

  for (int i = 0; i < combined_outputs.size(); i++) {
    const auto& output_tensor = combined_outputs[i];
    std::vector<Tensor> split_tensor;
    const Status split_status =
        tensor::Split(output_tensor, task_sizes, &split_tensor);
    if (!split_status.ok()) {
      return errors::Internal("Tensor split operation failed: ",
                              split_status.ToString());
    }
    for (int j = 0; j < batch->num_tasks(); ++j) {
      TfLiteBatchTask& task = *(batch->mutable_task(j));
      task.set_output(split_tensor[j]);
    }
  }

  return absl::OkStatus();
}
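
// Padding note (illustrative): if the tasks in the batch sum to 10 rows but
// the interpreter ran with a fixed batch size of 16, an extra entry of 6 is
// appended to task_sizes so that tensor::Split also slices off the padding
// rows; that final slice is never assigned to a task and is simply dropped.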

void TfLiteSession::ProcessBatch(
    std::unique_ptr<Batch<TfLiteBatchTask>> batch) {
  // As a possible performance optimization, consider overlapping the tensor
  // concatenation with waiting for the batch to close (i.e. do the
  // concatenation incrementally as tasks stream into the batch).
  batch->WaitUntilClosed();

  if (batch->empty()) {
    return;
  }

  const uint64_t dequeue_time_micros = EnvTime::NowMicros();

  // Regardless of the outcome, we need to propagate the status to the
  // individual tasks and signal that they are done. We use MakeCleanup() to
  // ensure that this happens no matter how we exit the method below.
  Status status;
  auto finally = gtl::MakeCleanup([&status, &batch] {
    for (int i = 0; i < batch->num_tasks(); ++i) {
      TfLiteBatchTask* task = batch->mutable_task(i);
      if (task->is_partial) {
        task->partial_status->Update(status);
        task->done_callback();
      } else {
        *batch->mutable_task(i)->status = status;
        batch->mutable_task(i)->done->Notify();
      }
    }
  });

  // Make sure we have at least one task that hasn't exceeded its timeout from
  // queue time alone, and find the latest task deadline which we'll use for the
  // overall batch.
  bool all_tasks_timeout_exceeded = true;
  uint64_t batch_deadline_micros = 0;
  for (int i = 0; i < batch->num_tasks(); ++i) {
    const TfLiteBatchTask& task = batch->task(i);
    // If the caller doesn't populate RunOptions, the timeout is 0 by default.
    // Interpret that as "no timeout".
    if (task.run_options.timeout_in_ms() <= 0) {
      all_tasks_timeout_exceeded = false;
      break;
    }
    const int64_t task_timeout_micros = task.run_options.timeout_in_ms() * 1000;
    const uint64_t task_deadline_micros =
        task.enqueue_time_micros + task_timeout_micros;
    if (task_deadline_micros > dequeue_time_micros) {
      all_tasks_timeout_exceeded = false;
      if (task_deadline_micros > batch_deadline_micros) {
        batch_deadline_micros = task_deadline_micros;
      }
    }
  }
  if (all_tasks_timeout_exceeded) {
    status = Status(static_cast<tensorflow::errors::Code>(
                        absl::StatusCode::kResourceExhausted),
                    "Run() timeout exceeded while waiting in batching queue");
    return;
  }

  std::vector<std::vector<const Tensor*>> merged_inputs;
  int batch_size = 0;
  status = MergeInputTensors(*batch, &merged_inputs, &batch_size);
  if (!status.ok()) {
    return;
  }
  std::vector<Tensor> combined_outputs;
  const auto& tflite_input_indices = batch->task(0).input_indices;
  auto& output_tensor_names = batch->task(0).output_tensor_names;
  int fixed_batch_size = batch_size;
  status = RunInternal(tflite_input_indices, merged_inputs,
                       *output_tensor_names, &combined_outputs, batch_size,
                       use_fixed_batch_size_ ? &fixed_batch_size : nullptr);
  if (!status.ok()) {
    return;
  }
  // The size of the batch might be smaller than the fixed_batch_size.
  status = SplitOutputTensors(combined_outputs, batch.get(), fixed_batch_size);
}

}  // namespace serving
}  // namespace tensorflow