16 #include "tensorflow_serving/servables/tensorflow/tflite_session.h"
26 #include "absl/functional/bind_front.h"
27 #include "tensorflow/cc/saved_model/signature_constants.h"
28 #include "tensorflow/core/framework/tensor_shape.pb.h"
29 #include "tensorflow/core/framework/tensor_util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/notification.h"
32 #include "tensorflow/core/lib/gtl/cleanup.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/lite/c/common.h"
35 #include "tensorflow/lite/kernels/cpu_backend_context.h"
36 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
37 #include "tensorflow/lite/kernels/parse_example/parse_example.h"
38 #include "tensorflow/lite/kernels/register.h"
39 #include "tensorflow/lite/string_util.h"
40 #include "tensorflow/lite/tools/signature/signature_def_util.h"
41 #include "tensorflow/lite/util.h"
42 #include "tensorflow_serving/batching/incremental_barrier.h"
43 #include "tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h"
namespace tensorflow {
namespace serving {
namespace {

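// Maps a TfLiteType to the corresponding TensorFlow DataType.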
Status TfLiteTypeToTfType(TfLiteType tflite_type, DataType* type) {
  switch (tflite_type) {
    case kTfLiteNoType:    *type = tensorflow::DT_INVALID;   break;
    case kTfLiteFloat32:   *type = tensorflow::DT_FLOAT;     break;
    case kTfLiteInt32:     *type = tensorflow::DT_INT32;     break;
    case kTfLiteUInt8:     *type = tensorflow::DT_UINT8;     break;
    case kTfLiteInt64:     *type = tensorflow::DT_INT64;     break;
    case kTfLiteString:    *type = tensorflow::DT_STRING;    break;
    case kTfLiteBool:      *type = tensorflow::DT_BOOL;      break;
    case kTfLiteInt16:     *type = tensorflow::DT_INT16;     break;
    case kTfLiteComplex64: *type = tensorflow::DT_COMPLEX64; break;
    case kTfLiteInt8:      *type = tensorflow::DT_INT8;      break;
    default:
      return errors::Internal("Unknown TfLite type: ", tflite_type);
  }
  return absl::OkStatus();
}

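// Converts a TF tensor name (e.g. "input:0") to the legacy TFLite tensor name
// by stripping the ":<index>" suffix.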
std::string TfToTfLiteLegacyTensorName(const string& tf_name) {
  std::pair<absl::string_view, absl::string_view> name_index =
      absl::StrSplit(tf_name, absl::MaxSplits(':', 1));
  return std::string(name_index.first);
}

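// Rewrites `tensor_name` so that it matches a key in `tensor_name_map`,
// falling back to the legacy (suffix-stripped) TFLite name. Returns an error
// if neither form is found.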
Status FixTfLiteTensorName(const std::map<string, int>& tensor_name_map,
                           string& tensor_name) {
  if (tensor_name_map.find(tensor_name) != tensor_name_map.end()) {
    return absl::OkStatus();
  }

  const string& legacy_tflite_name = TfToTfLiteLegacyTensorName(tensor_name);
  if (tensor_name_map.find(legacy_tflite_name) != tensor_name_map.end()) {
    tensor_name = legacy_tflite_name;
    return absl::OkStatus();
  }

  return errors::Internal("Unknown tensor '", tensor_name, "'.");
}

Status TfLiteTensorToTensorInfo(const TfLiteTensor* tflite_tensor,
                                TensorInfo* info) {
  DataType tf_type;
  TF_RETURN_IF_ERROR(TfLiteTypeToTfType(tflite_tensor->type, &tf_type));
  info->set_dtype(tf_type);
  info->set_name(tflite_tensor->name);
  for (int i = 0; i < tflite_tensor->dims->size; i++) {
    info->mutable_tensor_shape()->add_dim()->set_size(
        tflite_tensor->dims->data[i]);
  }
  return absl::OkStatus();
}

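// Builds a map from tensor name to (TensorInfo, tensor index) for either the
// inputs or the outputs of the interpreter. Fails if a tensor has no name or
// if a name appears more than once.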
Status GetTensorInfoMap(const tflite::Interpreter* interpreter, bool input,
                        TensorInfoMap* infomap) {
  const std::vector<int>& indices =
      input ? interpreter->inputs() : interpreter->outputs();
  const string& input_str = input ? "Input" : "Output";
  for (int index : indices) {
    const TfLiteTensor* tensor = interpreter->tensor(index);
    if (tensor->name == nullptr) {
      return errors::Internal(input_str,
                              " name missing for tensor index: ", index);
    }
    TensorInfo info;
    TF_RETURN_IF_ERROR(TfLiteTensorToTensorInfo(tensor, &info));
    if (!infomap->emplace(tensor->name, std::pair<TensorInfo, int>(info, index))
             .second) {
      return errors::AlreadyExists(input_str, " tensor name: ", tensor->name,
                                   " has multiple indices");
    }
  }
  return absl::OkStatus();
}

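// Returns the dimensions of a TF tensor as a vector of ints, the form
// expected by TFLite's ResizeInputTensor.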
std::vector<int> TensorDims(const Tensor& tensor) {
  std::vector<int> dims(tensor.dims());
  for (int i = 0; i < tensor.dims(); ++i) {
    dims[i] = static_cast<int>(tensor.dim_size(i));
  }
  return dims;
}

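// Allocates one TF output tensor per requested output name, shaped like the
// corresponding TFLite tensor, and records a TFLite-index -> Tensor* mapping.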
Status CreateOutputTensors(
    std::unique_ptr<internal::TfLiteInterpreterWrapper>& interpreter_wrapper,
    const std::vector<string>& output_tensor_names,
    const std::map<string, int>& output_tensor_to_idx,
    std::map<int32_t, Tensor*>& tflite_idx_to_output_tensor,
    std::vector<Tensor>* output_tensors) {
  output_tensors->reserve(output_tensor_names.size());
  for (std::string tfname : output_tensor_names) {
    auto fix_status = FixTfLiteTensorName(output_tensor_to_idx, tfname);
    if (fix_status != absl::OkStatus()) {
      return errors::Internal("Missing output TFLite tensor: ", tfname, ": ",
                              fix_status.message());
    }
    const int tflite_idx = output_tensor_to_idx.at(tfname);
    TensorShape tf_shape;
    const auto& interpreter = interpreter_wrapper->Get();
    const auto* tflite_tensor = interpreter->tensor(tflite_idx);
    for (int i = 0; i < tflite_tensor->dims->size; ++i) {
      tf_shape.AddDim(tflite_tensor->dims->data[i]);
    }
    DataType tf_type;
    TF_RETURN_IF_ERROR(TfLiteTypeToTfType(tflite_tensor->type, &tf_type));
    output_tensors->emplace_back(tf_type, tf_shape);
    tflite_idx_to_output_tensor[tflite_idx] = &output_tensors->back();
  }
  return absl::OkStatus();
}

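// Copies the (possibly concatenated) TF input tensors into the TFLite input
// tensors, resizing and reallocating as needed, then invokes the interpreter.
// String inputs are copied through the interpreter wrapper's SetStringData
// helper.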
Status SetInputAndInvokeMiniBatch(
    std::unique_ptr<internal::TfLiteInterpreterWrapper>& interpreter_wrapper,
    const std::vector<int>& tflite_input_indices,
    const std::vector<std::vector<const Tensor*>>& inputs, int batch_size,
    int* fixed_batch_size) {
  auto* interpreter = interpreter_wrapper->Get();
  for (int i = 0; i < tflite_input_indices.size(); ++i) {
    int tflite_input_idx = tflite_input_indices[i];
    auto tflite_input_tensor = interpreter->tensor(tflite_input_idx);
    const auto& tf_input_tensors = inputs[i];
    if (tflite_input_tensor->type != kTfLiteString) {
      const Tensor* tf_input_tensor = tf_input_tensors[0];
      Tensor concated;
      if (tf_input_tensors.size() > 1) {
        std::vector<Tensor> to_concatenate;
        to_concatenate.reserve(tf_input_tensors.size());
        for (const auto* t : tf_input_tensors) {
          to_concatenate.push_back(std::move(*t));
        }
        TF_RETURN_IF_ERROR(tensor::Concat(to_concatenate, &concated));
        tf_input_tensor = &concated;
      }
      auto tensor_bytes = tf_input_tensor->tensor_data();
      std::vector<int> tf_dims = TensorDims(*tf_input_tensor);
      std::vector<int> tflite_dims(
          tflite_input_tensor->dims->data,
          tflite_input_tensor->dims->data + tflite_input_tensor->dims->size);
      if (tensor_bytes.size() != tflite_input_tensor->bytes ||
          tf_dims != tflite_dims) {
        if (interpreter->ResizeInputTensor(tflite_input_idx, tf_dims) !=
            kTfLiteOk) {
          return errors::Internal(
              "Failed to resize input tensor: ", tflite_input_tensor->name,
              " from ", tflite_input_tensor->bytes, " to ", tensor_bytes.size(),
              " bytes.");
        }
        if (interpreter->AllocateTensors() != kTfLiteOk) {
          return errors::Internal("Failed to allocate tensors");
        }
      }
      std::memcpy(tflite_input_tensor->data.raw, tensor_bytes.data(),
                  tensor_bytes.size());
    } else {
      const bool needs_resize =
          fixed_batch_size ? batch_size > interpreter_wrapper->GetBatchSize()
                           : batch_size != interpreter_wrapper->GetBatchSize();
      if (needs_resize) {
        interpreter->ResizeInputTensor(tflite_input_idx, {batch_size});
        interpreter_wrapper->SetBatchSize(batch_size);
        if (interpreter->AllocateTensors() != kTfLiteOk) {
          return errors::Internal("Failed to allocate tensors");
        }
      }
      if (fixed_batch_size) {
        *fixed_batch_size = interpreter_wrapper->GetBatchSize();
      }
      TF_RETURN_IF_ERROR(interpreter_wrapper->SetStringData(
          tf_input_tensors, tflite_input_tensor, tflite_input_idx, batch_size));
    }
  }
  if (interpreter_wrapper->Invoke() != kTfLiteOk) {
    return errors::Internal("Failed to invoke TfLite interpreter");
  }
  return absl::OkStatus();
}

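// Copies the interpreter's output tensors back into the pre-allocated TF
// output tensors, using memcpy for POD dtypes and per-string copies for
// kTfLiteString outputs.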
Status SetMiniBatchOutput(
    std::unique_ptr<internal::TfLiteInterpreterWrapper>& interpreter_wrapper,
    const std::map<int, Tensor*>& tflite_idx_to_output_tensor,
    std::vector<Tensor>* outputs) {
  for (const auto& entry : tflite_idx_to_output_tensor) {
    Tensor* tensor = entry.second;
    const DataType tf_type = tensor->dtype();
    if (tensor->NumElements() == 0) {
      continue;
    }
    const auto* interpreter = interpreter_wrapper->Get();
    auto tflite_tensor = interpreter->tensor(entry.first);
    if (DataTypeCanUseMemcpy(tf_type)) {
      auto tensor_bytes = tensor->tensor_data();
      int offset = 0;
      size_t tflite_tensor_bytes = tflite_tensor->bytes;
      std::memcpy(const_cast<char*>(tensor_bytes.data() + offset),
                  tflite_tensor->data.raw, tflite_tensor_bytes);
    } else if (tflite_tensor->type == kTfLiteString) {
      const int string_count = tflite::GetStringCount(tflite_tensor);
      int num_strings = string_count;
      int offset = 0;
      auto str_tensors = tensor->flat<tstring>();
      for (int i = 0; i < num_strings; i++) {
        const auto& ref = tflite::GetString(tflite_tensor, i);
        str_tensors(i + offset).assign(ref.str, ref.len);
      }
    }
  }
  return absl::OkStatus();
}

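// Returns the first dimension of the model's single input tensor, interpreted
// as the model's batch size; models without a single input have no
// well-defined batch size (the caller only checks for values greater than 1).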
int GetModelBatchSize(const tflite::Model* model) {
  const auto* primary_subgraph = model->subgraphs()->Get(0);
  const auto* inputs = primary_subgraph->inputs();
  if (inputs->size() == 1) {
    const int tensor_id = inputs->Get(0);
    const auto* tensor = primary_subgraph->tensors()->Get(tensor_id);
    return tensor->shape()->Get(0);
  }
  // No single fixed batch size could be determined.
  return -1;
}

}  // namespace

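// Splits a large batching task into partial tasks no larger than
// max_batch_size, optionally filling the remaining slots of an open batch
// first. An IncrementalBarrier runs a completion callback after all partial
// tasks finish; that callback concatenates their outputs back into the
// original task and propagates the first error, if any.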
Status TfLiteSession::SplitTfLiteInputTask(
    std::unique_ptr<TfLiteBatchTask>* input_task_ptr,
    int open_batch_remaining_slot, int max_batch_size,
    std::vector<std::unique_ptr<TfLiteBatchTask>>* output_tasks) {
  auto* input_task = input_task_ptr->get();
  auto split_output =
      std::make_shared<std::vector<std::unique_ptr<std::vector<Tensor>>>>();
  auto partial_status = std::make_shared<ThreadSafeStatus>();
  auto split_task_done_callback = [split_output, partial_status, input_task]() {
    // Notify the original task no matter how this callback exits.
    auto cleanup = gtl::MakeCleanup([done_notification = input_task->done]() {
      done_notification->Notify();
    });

    if (!partial_status->status().ok()) {
      *input_task->status = partial_status->status();
      return;
    }

    int output_size = split_output->size();
    int tensor_size = (*split_output)[0]->size();
    for (int tensor_idx = 0; tensor_idx < tensor_size; ++tensor_idx) {
      Tensor output_tensor;
      std::vector<Tensor> to_concatenate;
      to_concatenate.reserve(output_size);
      for (int output_idx = 0; output_idx < output_size; ++output_idx) {
        to_concatenate.push_back(
            std::move((*(*split_output)[output_idx])[tensor_idx]));
      }
      const auto concat_status = tensor::Concat(to_concatenate, &output_tensor);
      if (!concat_status.ok()) {
        *input_task->status = concat_status;
        return;
      }
      input_task->outputs->push_back(output_tensor);
    }
    *input_task->status = absl::OkStatus();
  };

  IncrementalBarrier barrier(std::move(split_task_done_callback));
  std::vector<int64_t> output_task_sizes;

  if (open_batch_remaining_slot > 0) {
    output_task_sizes.push_back(open_batch_remaining_slot);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }

  for (int left_task_size = input_task->size() - open_batch_remaining_slot;
       left_task_size > 0; left_task_size -= max_batch_size) {
    int next_task_size = std::min(left_task_size, max_batch_size);
    output_task_sizes.push_back(next_task_size);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }

  const int output_task_num = output_task_sizes.size();
  output_tasks->reserve(output_task_num);
  for (int i = 0; i < output_task_num; ++i) {
    std::unique_ptr<TfLiteBatchTask> task;
    TfLiteBatchTask::CreatePartialTfLiteBatchTask(
        input_task->input_indices, input_task->output_tensor_names,
        (*split_output)[i].get(), barrier.Inc(), partial_status.get(), &task);
    output_tasks->push_back(std::move(task));
  }

  for (int i = 0; i < input_task->inputs.size(); ++i) {
    const Tensor& input = input_task->inputs[i];
    std::vector<Tensor> split_tensors;
    auto status = tensor::Split(input, output_task_sizes, &split_tensors);
    if (status != absl::OkStatus()) {
      return status;
    }
    for (int output_idx = 0; output_idx < output_task_num; ++output_idx) {
      auto& output_task = (*output_tasks)[output_idx];
      output_task->inputs.push_back(std::move(split_tensors[output_idx]));
    }
  }
  return absl::OkStatus();
}

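// Default SchedulerCreator: wraps BasicBatchScheduler<TfLiteBatchTask>::Create.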
Status TfLiteSession::CreateDefaultBasicBatchScheduler(
    const BasicBatchScheduler<TfLiteBatchTask>::Options& options,
    std::function<void(std::unique_ptr<Batch<TfLiteBatchTask>>)>
        process_batch_callback,
    std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>>* batch_scheduler) {
  std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>> basic_batch_scheduler;
  TF_RETURN_IF_ERROR(BasicBatchScheduler<TfLiteBatchTask>::Create(
      options, process_batch_callback, &basic_batch_scheduler));
  *batch_scheduler = std::move(basic_batch_scheduler);
  return absl::OkStatus();
}

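// Installs a batch scheduler built from `options` and enables
// fixed-batch-size execution; once set, Run() calls are queued as
// TfLiteBatchTasks and executed by ProcessBatch().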
Status TfLiteSession::SetScheduler(
    const SchedulerCreator& scheduler_creator,
    const BasicBatchScheduler<TfLiteBatchTask>::Options& options) {
  use_fixed_batch_size_ = true;
  scheduler_options_ = options;
  auto bound_scheduler_creator = absl::bind_front(
      &TfLiteSession::CreateDefaultBasicBatchScheduler, scheduler_options_);
  return bound_scheduler_creator(
      [this](std::unique_ptr<Batch<TfLiteBatchTask>> batch) {
        this->ProcessBatch(std::move(batch));
      },
      &scheduler_);
}

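// Builds a TfLiteSession from a serialized FlatBuffer model: constructs an
// interpreter to discover input/output tensors, populates `signatures` from
// the model's SignatureDefs (or generates a default one), creates an
// interpreter pool, and, when num_interpreters_per_pool > 1, installs a batch
// scheduler sized from the model's batch size.
//
// A minimal calling sketch (local variable names are illustrative only):
//
//   std::unique_ptr<TfLiteSession> session;
//   ::google::protobuf::Map<string, SignatureDef> signatures;
//   TF_RETURN_IF_ERROR(TfLiteSession::Create(
//       std::move(model_bytes), SessionOptions(), /*num_pools=*/1,
//       /*num_interpreters_per_pool=*/1, &session, &signatures));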
Status TfLiteSession::Create(string&& buffer, const SessionOptions& options,
                             int num_pools, int num_interpreters_per_pool,
                             std::unique_ptr<TfLiteSession>* tflite_session,
                             ::google::protobuf::Map<string, SignatureDef>* signatures) {
  auto model = tflite::FlatBufferModel::BuildFromModel(
      flatbuffers::GetRoot<tflite::Model>(buffer.data()));
  if (model == nullptr) {
    return errors::InvalidArgument("Cannot build FlatBufferModel from buffer.");
  }

  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::ops::custom::AddParseExampleOp(&resolver);

  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return errors::Internal("Cannot build Interpreter from buffer.");
  }

  TensorInfoMap inputs;
  TF_RETURN_IF_ERROR(GetTensorInfoMap(interpreter.get(), true, &inputs));
  TensorInfoMap outputs;
  TF_RETURN_IF_ERROR(GetTensorInfoMap(interpreter.get(), false, &outputs));

  // Map TFLite tensor names to tensor indices.
  std::map<string, int> input_tensor_to_index;
  std::map<string, int> output_tensor_to_index;
  for (const auto& info : inputs) {
    const string& tflite_tensor_name = info.first;
    input_tensor_to_index[tflite_tensor_name] = info.second.second;
  }
  for (const auto& info : outputs) {
    const string& tflite_tensor_name = info.first;
    output_tensor_to_index[tflite_tensor_name] = info.second.second;
  }

  // Attempt to read signature defs from the model file.
  std::map<string, SignatureDef> signature_defs;
  const auto status =
      tflite::GetSignatureDefMap(model->GetModel(), &signature_defs);
  if (status != absl::OkStatus()) {
    return errors::InvalidArgument(
        "Invalid SignatureDefs found in TfLite model: ", status.message());
  }
  const bool has_lite_signature_def = !signature_defs.empty();

  if (has_lite_signature_def) {
    // Verify that signature inputs/outputs refer to tensors that exist in the
    // interpreter, falling back to legacy TFLite names when needed.
    for (const auto& signature_item : signature_defs) {
      SignatureDef* tflite_signature = &(*signatures)[signature_item.first];
      tflite_signature->CopyFrom(signature_item.second);
      for (auto& input : *tflite_signature->mutable_inputs()) {
        TensorInfo* tensor_info = &input.second;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            FixTfLiteTensorName(input_tensor_to_index,
                                *tensor_info->mutable_name()),
            "Signature input ", input.first, " references an unknown tensor");
      }
      for (auto& output : *tflite_signature->mutable_outputs()) {
        TensorInfo* tensor_info = &output.second;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            FixTfLiteTensorName(output_tensor_to_index,
                                *tensor_info->mutable_name()),
            "Signature output ", output.first, " references an unknown tensor");
      }
    }
  } else {
    LOG(WARNING) << "No signature def found in TFLite model. Generating one.";
    SignatureDef* sigdef = &(*signatures)[kDefaultServingSignatureDefKey];
    for (const auto& info : inputs) {
      string tflite_tensor_name = TfToTfLiteLegacyTensorName(info.first);
      (*sigdef->mutable_inputs())[tflite_tensor_name] = info.second.first;
    }
    for (const auto& info : outputs) {
      string tflite_tensor_name = TfToTfLiteLegacyTensorName(info.first);
      (*sigdef->mutable_outputs())[tflite_tensor_name] = info.second.first;
    }
    sigdef->set_method_name(kPredictMethodName);
  }

  const int num_interpreters = std::max(1, num_pools);
  const int model_batch_size = GetModelBatchSize(model->GetModel());

  std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool;
  TF_RETURN_IF_ERROR(
      internal::TfLiteInterpreterPool::CreateTfLiteInterpreterPool(
          model.get(), options, num_interpreters, interpreter_pool));

  tflite_session->reset(new TfLiteSession(
      std::move(input_tensor_to_index), std::move(output_tensor_to_index),
      std::move(buffer), std::move(model), std::move(interpreter_pool)));

  if (num_interpreters_per_pool > 1) {
    const int default_allowed_batch =
        (internal::kInitialBatchSize + num_interpreters_per_pool - 1) /
        num_interpreters_per_pool;
    const int min_allowed_batch =
        model_batch_size > 1 ? model_batch_size : default_allowed_batch;
    const int max_enqueued_batches = num_interpreters * 100;
    BasicBatchScheduler<TfLiteBatchTask>::Options scheduler_options;
    scheduler_options.num_batch_threads = num_interpreters;
    scheduler_options.max_batch_size = internal::kInitialBatchSize;
    scheduler_options.enable_large_batch_splitting = true;
    scheduler_options.max_execution_batch_size = min_allowed_batch;
    scheduler_options.max_enqueued_batches = max_enqueued_batches;
    scheduler_options.split_input_task_func = SplitTfLiteInputTask;
    TF_RETURN_IF_ERROR(
        (*tflite_session)
            ->SetScheduler(&TfLiteSession::CreateDefaultBasicBatchScheduler,
                           scheduler_options));
  }
  return absl::OkStatus();
}

TfLiteSession::TfLiteSession(
    std::map<string, int>&& input_tensor_to_index,
    std::map<string, int>&& output_tensor_to_index, string&& buffer,
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool)
    : input_tensor_to_index_(std::move(input_tensor_to_index)),
      output_tensor_to_index_(std::move(output_tensor_to_index)),
      model_serialized_bytes_(std::move(buffer)),
      model_(std::move(model)),
      interpreter_pool_(std::move(interpreter_pool)) {}

Status TfLiteSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                          const std::vector<string>& output_tensor_names,
                          const std::vector<string>& target_node_names,
                          std::vector<Tensor>* outputs) {
  RunMetadata run_metadata;
  return Run(RunOptions(), inputs, output_tensor_names, target_node_names,
             outputs, &run_metadata);
}

Status TfLiteSession::Run(const RunOptions& run_options,
                          const std::vector<std::pair<string, Tensor>>& inputs,
                          const std::vector<string>& output_tensor_names,
                          const std::vector<string>& target_node_names,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata) {
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, run_metadata, thread::ThreadPoolOptions());
}

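// Runs one (possibly merged) mini-batch on an interpreter checked out from
// the pool; the RETURN_POOL_IF_ERROR macro below returns the interpreter to
// the pool before propagating any error.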
Status TfLiteSession::RunInternal(
    const std::vector<int>& tflite_input_indices,
    const std::vector<std::vector<const Tensor*>>& merged_inputs,
    const std::vector<string>& output_tensor_names,
    std::vector<Tensor>* combined_outputs, int batch_size,
    int* fixed_batch_size) {
#define RETURN_POOL_IF_ERROR(...)                                     \
  do {                                                                \
    ::tensorflow::Status _status = (__VA_ARGS__);                     \
    if (TF_PREDICT_FALSE(!_status.ok())) {                            \
      interpreter_pool_->ReturnInterpreter(std::move(interpreter));   \
      return _status;                                                 \
    }                                                                 \
  } while (0);

  auto interpreter = interpreter_pool_->GetInterpreter();
  RETURN_POOL_IF_ERROR(
      SetInputAndInvokeMiniBatch(interpreter, tflite_input_indices,
                                 merged_inputs, batch_size, fixed_batch_size));

  std::map<int32_t, Tensor*> tflite_idx_to_output_tensor;
  RETURN_POOL_IF_ERROR(CreateOutputTensors(
      interpreter, output_tensor_names, output_tensor_to_index_,
      tflite_idx_to_output_tensor, combined_outputs));

  RETURN_POOL_IF_ERROR(SetMiniBatchOutput(
      interpreter, tflite_idx_to_output_tensor, combined_outputs));

#undef RETURN_POOL_IF_ERROR
  interpreter_pool_->ReturnInterpreter(std::move(interpreter));
  return absl::OkStatus();
}

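// Resolves input names to TFLite tensor indices and either runs the request
// directly (no scheduler installed) or enqueues it as a TfLiteBatchTask and
// waits for the batched execution to complete.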
Status TfLiteSession::Run(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata,
    const thread::ThreadPoolOptions& thread_pool_options) {
  std::map<int, const Tensor*> tflite_idx_to_input_tensor;
  for (const auto& input : inputs) {
    string name = input.first;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        FixTfLiteTensorName(input_tensor_to_index_, name),
        "Missing input TFLite tensor: ", name);
    const int index = input_tensor_to_index_.at(name);
    tflite_idx_to_input_tensor[index] = &input.second;
  }
  outputs->reserve(output_tensor_names.size());
  if (scheduler_ == nullptr) {
    // No batch scheduler: run the request directly on an interpreter.
    std::vector<int> input_indices;
    std::vector<std::vector<const Tensor*>> inputs;
    for (const auto entry : tflite_idx_to_input_tensor) {
      const auto& tf_tensor = *entry.second;
      inputs.push_back({&tf_tensor});
      input_indices.push_back(entry.first);
    }
    const int batch_size =
        inputs.empty() || inputs[0].empty() ? 1 : inputs[0][0]->dim_size(0);
    return RunInternal(input_indices, inputs, output_tensor_names, outputs,
                       batch_size);
  }
  // Otherwise enqueue the request as a batch task and wait for completion.
  Notification done;
  Status status;
  std::unique_ptr<TfLiteBatchTask> task;
  TfLiteBatchTask::CreateTfLiteBatchTask(&output_tensor_names, outputs, &done,
                                         &status, &task);
  for (const auto entry : tflite_idx_to_input_tensor) {
    task->input_indices.push_back(entry.first);
    task->inputs.push_back(std::move(*entry.second));
  }
  TF_RETURN_IF_ERROR(scheduler_->Schedule(&task));
  done.WaitForNotification();
  return status;
}

Status TfLiteSession::ListDevices(std::vector<DeviceAttributes>* response) {
  return errors::Unimplemented("ListDevices is not yet supported.");
}

Status MergeInputTensors(const Batch<TfLiteBatchTask>& batch,
                         std::vector<std::vector<const Tensor*>>* merged_inputs,
                         int* batch_size) {
  if (batch.num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch.num_tasks());
  }
  const int tensors_per_task = batch.task(0).inputs.size();
  *batch_size = 0;
  for (int i = 0; i < tensors_per_task; ++i) {
    merged_inputs->emplace_back();
    std::vector<const Tensor*>& tensors_to_merge = merged_inputs->back();
    for (int j = 0; j < batch.num_tasks(); ++j) {
      const std::vector<Tensor>& inputs = batch.task(j).inputs;
      tensors_to_merge.push_back(&(inputs[i]));
      if (i == 0) {
        if (inputs[i].dims()) {
          *batch_size += inputs[i].dim_size(0);
        }
      }
    }
  }
  return absl::OkStatus();
}

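// Splits each combined output tensor along dimension 0 back into per-task
// outputs; if the executed batch was padded beyond the tasks' total size, the
// extra rows land in a trailing split that no task consumes.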
Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
                          Batch<TfLiteBatchTask>* batch, int batch_size) {
  std::vector<int64_t> task_sizes(batch->num_tasks());
  int total_size = 0;
  for (int i = 0; i < batch->num_tasks(); ++i) {
    const int task_size = batch->task(i).size();
    task_sizes[i] = task_size;
    total_size += task_size;
  }

  if (total_size < batch_size) {
    task_sizes.push_back(batch_size - total_size);
  }

  for (int i = 0; i < combined_outputs.size(); i++) {
    const auto& output_tensor = combined_outputs[i];
    std::vector<Tensor> split_tensor;
    const Status split_status =
        tensor::Split(output_tensor, task_sizes, &split_tensor);
    if (!split_status.ok()) {
      return errors::Internal("Tensor split operation failed: ",
                              split_status.ToString());
    }
    for (int j = 0; j < batch->num_tasks(); ++j) {
      TfLiteBatchTask& task = *(batch->mutable_task(j));
      task.set_output(split_tensor[j]);
    }
  }
  return absl::OkStatus();
}

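// Executes one closed batch: checks per-task timeouts, merges the task
// inputs, runs the mini-batch, and splits the outputs back to the tasks. The
// cleanup handler propagates `status` to every task (partial or not) when
// this method returns.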
void TfLiteSession::ProcessBatch(
    std::unique_ptr<Batch<TfLiteBatchTask>> batch) {
  batch->WaitUntilClosed();

  if (batch->empty()) {
    return;
  }

  const uint64_t dequeue_time_micros = EnvTime::NowMicros();

  Status status;
  auto finally = gtl::MakeCleanup([&status, &batch] {
    for (int i = 0; i < batch->num_tasks(); ++i) {
      TfLiteBatchTask* task = batch->mutable_task(i);
      if (task->is_partial) {
        task->partial_status->Update(status);
        task->done_callback();
      } else {
        *batch->mutable_task(i)->status = status;
        batch->mutable_task(i)->done->Notify();
      }
    }
  });

  bool all_tasks_timeout_exceeded = true;
  uint64_t batch_deadline_micros = 0;
  for (int i = 0; i < batch->num_tasks(); ++i) {
    const TfLiteBatchTask& task = batch->task(i);
    // A timeout of zero or less means the task has no deadline.
    if (task.run_options.timeout_in_ms() <= 0) {
      all_tasks_timeout_exceeded = false;
      break;
    }
    const int64_t task_timeout_micros = task.run_options.timeout_in_ms() * 1000;
    const uint64_t task_deadline_micros =
        task.enqueue_time_micros + task_timeout_micros;
    if (task_deadline_micros > dequeue_time_micros) {
      all_tasks_timeout_exceeded = false;
      if (task_deadline_micros > batch_deadline_micros) {
        batch_deadline_micros = task_deadline_micros;
      }
    }
  }
  if (all_tasks_timeout_exceeded) {
    status = Status(static_cast<tensorflow::errors::Code>(
                        absl::StatusCode::kResourceExhausted),
                    "Run() timeout exceeded while waiting in batching queue");
    return;
  }

  std::vector<std::vector<const Tensor*>> merged_inputs;
  int batch_size = 0;
  status = MergeInputTensors(*batch, &merged_inputs, &batch_size);
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> combined_outputs;
  const auto& tflite_input_indices = batch->task(0).input_indices;
  auto& output_tensor_names = batch->task(0).output_tensor_names;
  int fixed_batch_size = batch_size;
  status = RunInternal(tflite_input_indices, merged_inputs,
                       *output_tensor_names, &combined_outputs, batch_size,
                       use_fixed_batch_size_ ? &fixed_batch_size : nullptr);
  if (!status.ok()) {
    return;
  }
  status = SplitOutputTensors(combined_outputs, batch.get(), fixed_batch_size);