16 #include "tensorflow_serving/servables/tensorflow/tflite_session.h"
26 #include <gtest/gtest.h>
27 #include "absl/flags/flag.h"
28 #include "absl/functional/bind_front.h"
29 #include "flatbuffers/flexbuffers.h"
30 #include "tensorflow/cc/saved_model/signature_constants.h"
31 #include "tensorflow/core/example/example.pb.h"
32 #include "tensorflow/core/example/feature.pb.h"
33 #include "tensorflow/core/framework/attr_value.pb.h"
34 #include "tensorflow/core/framework/node_def.pb.h"
35 #include "tensorflow/core/framework/tensor_testutil.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/lib/core/status_test_util.h"
38 #include "tensorflow/core/platform/env.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/protobuf.h"
41 #include "tensorflow/core/platform/strcat.h"
42 #include "tensorflow/core/platform/test_benchmark.h"
43 #include "tensorflow/core/platform/threadpool_options.h"
44 #include "tensorflow/core/protobuf/config.pb.h"
45 #include "tensorflow/core/protobuf/error_codes.pb.h"
46 #include "tensorflow/lite/string_util.h"
47 #include "tensorflow/lite/tools/signature/signature_def_util.h"
48 #include "tensorflow/lite/util.h"
49 #include "tensorflow/lite/version.h"
50 #include "tensorflow_serving/test_util/test_util.h"
ABSL_FLAG(int, num_pools, 1,
          "Number of interpreter pools of a TfLiteSession.");
ABSL_FLAG(int, num_tflite_interpreters, 1,
          "Number of TFLite interpreters "
          "in an interpreter pool of a TfLiteSession.");
namespace tensorflow {
namespace serving {
namespace {

using ::testing::_;
using ::testing::Pair;
using ::testing::SizeIs;
using ::testing::UnorderedElementsAre;
constexpr char kTestModel[] =
    "/servables/tensorflow/testdata/saved_model_half_plus_two_tflite/00000123/";

constexpr char kTestModelWithSigdef[] =
    "/servables/tensorflow/testdata/"
    "saved_model_half_plus_two_tflite_with_sigdef/00000123/model.tflite";

constexpr char kMobileNetModel[] =
    "/servables/tensorflow/testdata/mobilenet_v1_quant_tflite/00000123/";

constexpr char kParseExampleModel[] =
    "/servables/tensorflow/testdata/parse_example_tflite/00000123/";
TEST(TfLiteSession, BasicTest) {
  string model_bytes;
  TF_ASSERT_OK(ReadFileToString(tensorflow::Env::Default(),
                                test_util::TestSrcDirPath(kTestModel),
                                &model_bytes));

  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_ASSERT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));
  EXPECT_EQ(signatures.size(), 1);
  EXPECT_EQ(signatures.begin()->first, "serving_default");
  EXPECT_THAT(signatures.begin()->second, test_util::EqualsProto(R"(
                method_name: "tensorflow/serving/predict"
              )"));

  Tensor input = test::AsTensor<float>({1.0, 2.0, 3.0}, TensorShape({3}));
  {
    // Run with the TF Lite tensor names.
    std::vector<Tensor> outputs;
    TF_EXPECT_OK(session->Run({{"x", input}}, {"y"}, {}, &outputs));
    ASSERT_EQ(outputs.size(), 1);
    test::ExpectTensorEqual<float>(
        outputs[0], test::AsTensor<float>({2.5, 3, 3.5}, TensorShape({3})));
  }
  {
    // Run with TF-style tensor names (":0" suffix).
    std::vector<Tensor> outputs;
    TF_EXPECT_OK(session->Run({{"x:0", input}}, {"y:0"}, {}, &outputs));
    ASSERT_EQ(outputs.size(), 1);
    test::ExpectTensorEqual<float>(
        outputs[0], test::AsTensor<float>({2.5, 3, 3.5}, TensorShape({3})));
  }
}
TEST(TfLiteSession, ResizeWithSameNumElementsTest) {
  string model_bytes;
  TF_ASSERT_OK(ReadFileToString(tensorflow::Env::Default(),
                                test_util::TestSrcDirPath(kTestModel),
                                &model_bytes));

  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_ASSERT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));
  EXPECT_EQ(signatures.size(), 1);
  EXPECT_EQ(signatures.begin()->first, "serving_default");
  EXPECT_THAT(signatures.begin()->second, test_util::EqualsProto(R"pb(
                method_name: "tensorflow/serving/predict"
              )pb"));

  Tensor input = test::AsTensor<float>({2.0}, TensorShape({1}));
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session->Run({{"x", input}}, {"y"}, {}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<float>(
      outputs[0], test::AsTensor<float>({3.0}, TensorShape({1})));
}
TEST(TfLiteSession, ModelFromLegacyConverterWithSigdef) {
  string model_bytes;
  TF_ASSERT_OK(ReadFileToString(tensorflow::Env::Default(),
                                test_util::TestSrcDirPath(kTestModelWithSigdef),
                                &model_bytes));

  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_ASSERT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));
  EXPECT_EQ(signatures.size(), 1);
  EXPECT_EQ(signatures.begin()->first, "serving_default");
  EXPECT_THAT(signatures.begin()->second, test_util::EqualsProto(R"(
                method_name: "tensorflow/serving/predict"
              )"));

  Tensor input = test::AsTensor<float>({1.0, 2.0, 3.0}, TensorShape({3}));
  {
    // Run with the TF Lite tensor name recorded in the model.
    std::vector<Tensor> outputs;
    TF_EXPECT_OK(
        session->Run({{"tflite_input", input}}, {"y"}, {}, &outputs));
    ASSERT_EQ(outputs.size(), 1);
    test::ExpectTensorEqual<float>(
        outputs[0], test::AsTensor<float>({2.5, 3, 3.5}, TensorShape({3})));
  }
  {
    // Run with TF-style tensor names (":0" suffix).
    std::vector<Tensor> outputs;
    TF_EXPECT_OK(
        session->Run({{"tflite_input:0", input}}, {"y:0"}, {}, &outputs));
    ASSERT_EQ(outputs.size(), 1);
    test::ExpectTensorEqual<float>(
        outputs[0], test::AsTensor<float>({2.5, 3, 3.5}, TensorShape({3})));
  }
}
constexpr char kTestModelInputList[] = "list";
constexpr char kTestModelInputShape[] = "shape";
constexpr char kTestModelOutput[] = "output";

constexpr char kSignatureInputList[] = "input_list";
constexpr char kSignatureInputShape[] = "input_shape";
constexpr char kSignatureOutput[] = "sigdef_output";

constexpr int kBatchSize = 500;
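
// Builds a SignatureDef map for the hand-built test model below, mapping the
// signature-level input/output keys to the TF Lite tensor names (with a ":0"
// suffix) under the default serving signature key.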
std::map<string, SignatureDef> GetTestSignatureDefMap() {
  auto signature_def = SignatureDef();
  TensorInfo input_list_tensor;
  TensorInfo input_shape_tensor;
  TensorInfo output_tensor;
  *input_list_tensor.mutable_name() = absl::StrCat(kTestModelInputList, ":0");
  *input_shape_tensor.mutable_name() =
      absl::StrCat(kTestModelInputShape, ":0");
  *output_tensor.mutable_name() = absl::StrCat(kTestModelOutput, ":0");
  *signature_def.mutable_method_name() = kClassifyMethodName;
  (*signature_def.mutable_inputs())[kSignatureInputList] = input_list_tensor;
  (*signature_def.mutable_inputs())[kSignatureInputShape] = input_shape_tensor;
  (*signature_def.mutable_outputs())[kSignatureOutput] = output_tensor;
  std::map<string, SignatureDef> signature_def_map = {
      {kDefaultServingSignatureDefKey, signature_def}};
  return signature_def_map;
}
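
// Maps a TF Lite tensor type to the corresponding TensorFlow dtype. Only the
// types used by BuildTestModel() below are handled; anything else is fatal.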
tensorflow::DataType ToTfTensorType(tflite::TensorType tflite_type) {
  switch (tflite_type) {
    case tflite::TensorType_INT32:
      return tensorflow::DT_INT32;
    case tflite::TensorType_STRING:
      return tensorflow::DT_STRING;
    default:
      LOG(FATAL) << "Unsupported tflite type: " << tflite_type;
  }
}
string BuildTestModel(tflite::TensorType tensor_type,
                      const string& input1_tensor_name,
                      const string& input2_tensor_name, bool use_flex_op,
                      std::map<string, SignatureDef>* signature_def_map) {
  std::vector<int32_t> inputs;
  std::vector<int32_t> outputs;
  std::vector<flatbuffers::Offset<tflite::Tensor>> tensors;
  std::vector<flatbuffers::Offset<tflite::OperatorCode>> opcodes;
  std::vector<flatbuffers::Offset<tflite::Operator>> operators;
  std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
  flatbuffers::FlatBufferBuilder builder;

  // First input: the list of values to reshape.
  inputs.push_back(tensors.size());
  tensors.push_back(CreateTensor(builder, builder.CreateVector<int>({1}),
                                 tensor_type, /*buffer=*/0,
                                 builder.CreateString(input1_tensor_name)));

  // Second input: the target shape.
  inputs.push_back(tensors.size());
  tensors.push_back(CreateTensor(builder, builder.CreateVector<int>({1}),
                                 tflite::TensorType_INT32, 0,
                                 builder.CreateString(input2_tensor_name)));

  // Output tensor.
  outputs.push_back(tensors.size());
  tensors.push_back(CreateTensor(builder, builder.CreateVector<int>({1}),
                                 tensor_type, /*buffer=*/0,
                                 builder.CreateString(kTestModelOutput)));

  tflite::BuiltinOptions builtin_opts_type =
      tflite::BuiltinOptions_ReshapeOptions;
  flatbuffers::Offset<void> reshape_opts =
      tflite::CreateReshapeOptions(builder, builder.CreateVector<int>({}))
          .Union();
  flatbuffers::Offset<flatbuffers::Vector<uint8_t>> custom_opts = 0;
  if (use_flex_op) {
    // Emit Reshape as a flex (TF) op: a CUSTOM opcode whose custom options
    // carry the serialized NodeDef.
    string flexop = std::string(tflite::kFlexCustomCodePrefix) + "Reshape";
    opcodes.push_back(CreateOperatorCodeDirect(
        builder, tflite::BuiltinOperator_CUSTOM, flexop.data()));
    builtin_opts_type = tflite::BuiltinOptions_NONE;
    NodeDef node_def;
    node_def.set_name("Reshape");
    node_def.set_op("Reshape");
    (*node_def.mutable_attr())["T"].set_type(ToTfTensorType(tensor_type));
    string node_def_str;
    CHECK(node_def.SerializeToString(&node_def_str));
    auto flex_builder = absl::make_unique<flexbuffers::Builder>();
    flex_builder->Vector([&]() {
      flex_builder->String(node_def.op());
      flex_builder->String(node_def_str);
    });
    flex_builder->Finish();
    custom_opts = builder.CreateVector(flex_builder->GetBuffer());
  } else {
    opcodes.push_back(
        CreateOperatorCode(builder, tflite::BuiltinOperator_RESHAPE, 0));
  }

  operators.push_back(CreateOperator(
      builder, 0, builder.CreateVector<int32_t>(inputs),
      builder.CreateVector<int32_t>(outputs), builtin_opts_type, reshape_opts,
      custom_opts, tflite::CustomOptionsFormat_FLEXBUFFERS));

  auto subgraph = CreateSubGraph(builder, builder.CreateVector(tensors),
                                 builder.CreateVector<int32_t>(inputs),
                                 builder.CreateVector<int32_t>(outputs),
                                 builder.CreateVector(operators));
  builder.Finish(CreateModel(
      builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(opcodes),
      builder.CreateVector(&subgraph, 1), builder.CreateString("testmodel"),
      builder.CreateVector(buffers)));

  if (signature_def_map) {
    std::string model_buffer =
        string(reinterpret_cast<char*>(builder.GetBufferPointer()),
               builder.GetSize());
    std::string model_buffer_with_signature_def;
    auto model = tflite::FlatBufferModel::BuildFromModel(
        flatbuffers::GetRoot<tflite::Model>(model_buffer.data()));
    TF_CHECK_OK(tflite::SetSignatureDefMap(model->GetModel(),
                                           *signature_def_map,
                                           &model_buffer_with_signature_def));
    return model_buffer_with_signature_def;
  }
  return string(reinterpret_cast<char*>(builder.GetBufferPointer()),
                builder.GetSize());
}
string BuildTestModel(tflite::TensorType tensor_type, bool use_flex_op,
                      std::map<string, SignatureDef>* signature_def_map) {
  return BuildTestModel(tensor_type, kTestModelInputList, kTestModelInputShape,
                        use_flex_op, signature_def_map);
}
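
// Runs the string test model end to end with the builtin Reshape op: a
// 4-element string tensor is reshaped to 2x2.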
TEST(TfLiteSession, ProcessStrings) {
  auto model_signature_def_map = GetTestSignatureDefMap();
  string model_bytes =
      BuildTestModel(tflite::TensorType_STRING, /*use_flex_op=*/false,
                     &model_signature_def_map);
  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_ASSERT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));
  Tensor input_list =
      test::AsTensor<tstring>({"a", "b", "c", "d"}, TensorShape({4}));
  Tensor input_shape = test::AsTensor<int32>({2, 2}, TensorShape({2}));
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session->Run(
      {{kTestModelInputList, input_list}, {kTestModelInputShape, input_shape}},
      {kTestModelOutput}, {}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<tstring>(
      outputs[0],
      test::AsTensor<tstring>({"a", "b", "c", "d"}, TensorShape({2, 2})));
}
TEST(TfLiteSession, ProcessStringsFlex) {
  auto model_signature_def_map = GetTestSignatureDefMap();
  string model_bytes =
      BuildTestModel(tflite::TensorType_STRING, /*use_flex_op=*/true,
                     &model_signature_def_map);
  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_ASSERT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));
  Tensor input_list =
      test::AsTensor<tstring>({"a", "b", "c", "d"}, TensorShape({4}));
  Tensor input_shape = test::AsTensor<int32>({2, 2}, TensorShape({2}));
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session->Run(
      {{kTestModelInputList, input_list}, {kTestModelInputShape, input_shape}},
      {kTestModelOutput}, {}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<tstring>(
      outputs[0],
      test::AsTensor<tstring>({"a", "b", "c", "d"}, TensorShape({2, 2})));
}
TEST(TfLiteSession, ThreadPoolOptions) {
  auto model_signature_def_map = GetTestSignatureDefMap();
  string model_bytes =
      BuildTestModel(tflite::TensorType_STRING, /*use_flex_op=*/false,
                     &model_signature_def_map);
  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_ASSERT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));
  Tensor input_list =
      test::AsTensor<tstring>({"a", "b", "c", "d"}, TensorShape({4}));
  Tensor input_shape = test::AsTensor<int32>({2, 2}, TensorShape({2}));
  std::vector<Tensor> outputs;
  RunMetadata run_metadata;
  thread::ThreadPoolOptions thread_pool_options;
  test_util::CountingThreadPool inter_op_threadpool(Env::Default(), "InterOp",
                                                    /*num_threads=*/1);
  test_util::CountingThreadPool intra_op_threadpool(Env::Default(), "IntraOp",
                                                    /*num_threads=*/1);
  thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
  thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
  TF_EXPECT_OK(session->Run(
      RunOptions(),
      {{kTestModelInputList, input_list}, {kTestModelInputShape, input_shape}},
      {kTestModelOutput}, {}, &outputs, &run_metadata, thread_pool_options));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<tstring>(
      outputs[0],
      test::AsTensor<tstring>({"a", "b", "c", "d"}, TensorShape({2, 2})));

  // The provided thread pools should not be used by the TF Lite interpreter.
  EXPECT_EQ(inter_op_threadpool.NumScheduled(), 0);
  EXPECT_EQ(intra_op_threadpool.NumScheduled(), 0);
}
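
// Verifies that the SignatureDef embedded in the model is surfaced under the
// default serving key and that any pre-existing entries are discarded.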
TEST(TfLiteSession, SimpleSignatureDef) {
  auto model_signature_def_map = GetTestSignatureDefMap();
  string model_bytes =
      BuildTestModel(tflite::TensorType_STRING, /*use_flex_op=*/false,
                     &model_signature_def_map);

  // Pre-populate `signatures` with an entry that should be dropped in favor
  // of the SignatureDef stored in the model.
  ::google::protobuf::Map<string, SignatureDef> signatures;
  string kResidualSignatureKey = "residual_signature";
  signatures[kResidualSignatureKey] = SignatureDef();

  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_ASSERT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));

  ASSERT_THAT(signatures,
              UnorderedElementsAre(Pair(kDefaultServingSignatureDefKey, _)));

  auto sigdef = signatures[kDefaultServingSignatureDefKey];
  EXPECT_EQ(sigdef.inputs().at(kSignatureInputList).name(),
            kTestModelInputList);
  EXPECT_EQ(sigdef.inputs().at(kSignatureInputShape).name(),
            kTestModelInputShape);
  EXPECT_EQ(sigdef.outputs().at(kSignatureOutput).name(), kTestModelOutput);
  EXPECT_EQ(sigdef.method_name(), kClassifyMethodName);
}
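
// Verifies that every SignatureDef embedded in the model is surfaced, keyed by
// its original signature name.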
TEST(TfLiteSession, MultipleSignatureDef) {
  TensorInfo input_list_tensor;
  TensorInfo input_shape_tensor;
  TensorInfo output_tensor;
  *input_list_tensor.mutable_name() = kTestModelInputList;
  *input_shape_tensor.mutable_name() = kTestModelInputShape;
  *output_tensor.mutable_name() = kTestModelOutput;
  SignatureDef signature1 = SignatureDef();
  *signature1.mutable_method_name() = kClassifyMethodName;
  (*signature1.mutable_inputs())[kSignatureInputList] = input_list_tensor;
  (*signature1.mutable_outputs())[kSignatureOutput] = output_tensor;
  SignatureDef signature2 = SignatureDef();
  *signature2.mutable_method_name() = kClassifyMethodName;
  (*signature2.mutable_inputs())[kSignatureInputShape] = input_shape_tensor;
  (*signature2.mutable_outputs())[kSignatureOutput] = output_tensor;
  constexpr char kSignatureKey1[] = "signature1";
  constexpr char kSignatureKey2[] = "signature2";
  std::map<string, SignatureDef> signature_def_map = {
      {kSignatureKey1, signature1}, {kSignatureKey2, signature2}};

  string model_bytes = BuildTestModel(
      tflite::TensorType_STRING, /*use_flex_op=*/false, &signature_def_map);
  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_EXPECT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));

  ASSERT_THAT(signatures, UnorderedElementsAre(Pair(kSignatureKey1, _),
                                               Pair(kSignatureKey2, _)));
  auto result_signature1 = signatures[kSignatureKey1];
  EXPECT_EQ(result_signature1.inputs().at(kSignatureInputList).name(),
            kTestModelInputList);
  EXPECT_EQ(result_signature1.outputs().at(kSignatureOutput).name(),
            kTestModelOutput);
  EXPECT_EQ(result_signature1.method_name(), kClassifyMethodName);
  auto result_signature2 = signatures[kSignatureKey2];
  EXPECT_EQ(result_signature2.inputs().at(kSignatureInputShape).name(),
            kTestModelInputShape);
  EXPECT_EQ(result_signature2.outputs().at(kSignatureOutput).name(),
            kTestModelOutput);
  EXPECT_EQ(result_signature2.method_name(), kClassifyMethodName);
}
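
// Two inputs whose tensor names share the prefix "myTensor" must still be
// reported as two distinct tensors in the resulting SignatureDef.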
TEST(TfLiteSession, SignatureDefWithCommonTensorPrefix) {
  SignatureDef signature;
  protobuf::TextFormat::ParseFromString(R"(
        tensor_shape { dim { size: 1 } }
        tensor_shape { dim { size: 1 } }
        method_name: "tensorflow/serving/predict"
      )",
                                        &signature);

  std::map<string, SignatureDef> signature_def_map = {
      {kDefaultServingSignatureDefKey, signature}};
  string model_bytes =
      BuildTestModel(tflite::TensorType_STRING, "myTensor:0", "myTensor:1",
                     /*use_flex_op=*/false, &signature_def_map);
  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_ASSERT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));

  auto outputSigdef = signatures[kDefaultServingSignatureDefKey];
  std::set<string> tensorNamesSet;
  for (const auto& input : outputSigdef.inputs()) {
    tensorNamesSet.insert(input.second.name());
  }
  EXPECT_THAT(tensorNamesSet, SizeIs(2));
}
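
// Loads the model via its SignatureDef and then runs it using the TF Lite
// tensor names recorded there.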
TEST(TfLiteSession, SimpleSignatureDefAndRun) {
  auto model_signature_def_map = GetTestSignatureDefMap();
  string model_bytes =
      BuildTestModel(tflite::TensorType_STRING, /*use_flex_op=*/false,
                     &model_signature_def_map);
  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> session;
  tensorflow::SessionOptions options;
  TF_EXPECT_OK(TfLiteSession::Create(
      std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
      absl::GetFlag(FLAGS_num_tflite_interpreters), &session, &signatures));

  auto sigdef = signatures[kDefaultServingSignatureDefKey];
  ASSERT_EQ(sigdef.inputs().at(kSignatureInputList).name(),
            kTestModelInputList);
  ASSERT_EQ(sigdef.inputs().at(kSignatureInputShape).name(),
            kTestModelInputShape);
  ASSERT_EQ(sigdef.outputs().at(kSignatureOutput).name(), kTestModelOutput);
  ASSERT_EQ(sigdef.method_name(), kClassifyMethodName);

  Tensor input_list =
      test::AsTensor<tstring>({"a", "b", "c", "d"}, TensorShape({4}));
  Tensor input_shape = test::AsTensor<int32>({2, 2}, TensorShape({2}));
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session->Run(
      {{kTestModelInputList, input_list}, {kTestModelInputShape, input_shape}},
      {kTestModelOutput}, {}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<tstring>(
      outputs[0],
      test::AsTensor<tstring>({"a", "b", "c", "d"}, TensorShape({2, 2})));
}
Status BuildSessionInBatch(std::unique_ptr<TfLiteSession>* sess,
                           bool use_model_batch_size,
                           const string& model_path) {
  std::string model_bytes;
  TF_RETURN_IF_ERROR(ReadFileToString(
      Env::Default(), test_util::TestSrcDirPath(model_path), &model_bytes));
  auto model = tflite::FlatBufferModel::BuildFromModel(
      flatbuffers::GetRoot<tflite::Model>(model_bytes.data()));
  const int model_batch_size = 5;
  if (use_model_batch_size) {
    // Rewrite the model so that its single input has a fixed batch size.
    const tflite::Model* tflite_model = model->GetModel();
    auto mutable_model = absl::make_unique<tflite::ModelT>();
    tflite_model->UnPackTo(mutable_model.get(), nullptr);

    if (mutable_model->subgraphs.size() != 1) {
      return Status(
          static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
          strings::StrCat("Model subgraph size ",
                          mutable_model->subgraphs.size(),
                          " not equal to 1"));
    }
    auto* subgraph = mutable_model->subgraphs[0].get();
    if (subgraph->inputs.size() != 1) {
      return Status(
          static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
          strings::StrCat("Model subgraph input size ",
                          subgraph->inputs.size(), " not equal to 1"));
    }
    auto* tensor = subgraph->tensors[subgraph->inputs[0]].get();
    if (tensor->shape[0] != 1) {
      return Status(
          static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
          strings::StrCat("Model subgraph input shape[0] ",
                          tensor->shape[0], " not equal to 1"));
    }
    tensor->shape[0] = model_batch_size;
    flatbuffers::FlatBufferBuilder builder;
    auto packed_model = tflite::Model::Pack(builder, mutable_model.get());
    FinishModelBuffer(builder, packed_model);
    model_bytes =
        string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
               builder.GetSize());
  }
  auto model_signature_def_map = GetTestSignatureDefMap();
  ::google::protobuf::Map<string, SignatureDef> signatures;
  tensorflow::SessionOptions options;
  const int num_tflite_interpreters = 4;

  TF_RETURN_IF_ERROR(TfLiteSession::Create(std::move(model_bytes), options, 1,
                                           num_tflite_interpreters, sess,
                                           &signatures));

  auto scheduler_options = (*sess)->GetSchedulerOptions();
  const int expected_batch_size = use_model_batch_size
                                      ? model_batch_size
                                      : kBatchSize / num_tflite_interpreters;
  if (scheduler_options.max_execution_batch_size != expected_batch_size) {
    return Status(
        static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
        strings::StrCat("Scheduler max_execution_batch_size ",
                        scheduler_options.max_execution_batch_size,
                        " not equal to expected ", expected_batch_size));
  }
  return OkStatus();
}
using TfLiteSessionBatchSizeTest = ::testing::TestWithParam<bool>;

TEST_P(TfLiteSessionBatchSizeTest, TestBatchParallelismForFloat) {
  std::unique_ptr<TfLiteSession> sess;
  TF_ASSERT_OK(BuildSessionInBatch(&sess, GetParam(), kTestModel));
  std::vector<float> example_list;
  std::vector<float> expected;
  std::vector<tstring> expected_bytes;
  std::vector<Tensor> outputs;

  std::mt19937 random_engine;
  auto random_func = [&]() {
    return std::uniform_real_distribution<float>(-0.5, 0.5)(random_engine);
  };
  for (int i = 0; i < kBatchSize; i++) {
    example_list.push_back(random_func());
  }

  Tensor example_list_tensor =
      test::AsTensor<float>(example_list, TensorShape({kBatchSize, 1}));
  TF_EXPECT_OK(sess->Run({{"x", example_list_tensor}}, {"y"}, {}, &outputs));
  EXPECT_TRUE(outputs[0].shape().IsSameSize(TensorShape({kBatchSize, 1})));
}
TEST_P(TfLiteSessionBatchSizeTest, TestBatchParallelismForString) {
  std::unique_ptr<TfLiteSession> sess;
  TF_ASSERT_OK(BuildSessionInBatch(&sess, GetParam(), kParseExampleModel));
  const float default_value = 0;
  std::vector<tstring> example_list;
  std::vector<float> expected;
  std::vector<tstring> expected_bytes;
  std::vector<Tensor> outputs;

  std::mt19937 random_engine;
  auto random_func = [&]() {
    return std::uniform_real_distribution<float>(-0.5, 0.5)(random_engine);
  };
  const std::string kTestString = "test string";
  const std::string kDefaultString = "missing";
  for (int i = 0; i < kBatchSize; i++) {
    float val = random_func();
    tensorflow::Example example;
    string str;
    // Alternate between empty examples, which parse to the defaults, and
    // examples carrying the generated values.
    if (i % 2 == 0) {
      expected.push_back(default_value);
      expected_bytes.push_back(kDefaultString);
    } else {
      expected.push_back(val);
      expected_bytes.push_back(kTestString);
      auto* features = example.mutable_features();
      (*features->mutable_feature())["x"].mutable_float_list()->add_value(val);
      (*features->mutable_feature())["y"].mutable_bytes_list()->add_value(
          kTestString);
    }
    example.SerializeToString(&str);
    example_list.push_back(str);
  }

  Tensor example_list_tensor =
      test::AsTensor<tstring>(example_list, TensorShape({kBatchSize}));
  TF_EXPECT_OK(sess->Run(
      {{"input", example_list_tensor}},
      {"ParseExample/ParseExampleV2", "ParseExample/ParseExampleV2:1"}, {},
      &outputs));
  test::ExpectTensorEqual<float>(
      outputs[0],
      test::AsTensor<float>(expected, TensorShape({kBatchSize, 1})));
  EXPECT_EQ(outputs.size(), 2);
  test::ExpectTensorEqual<tstring>(
      outputs[1],
      test::AsTensor<tstring>(expected_bytes, TensorShape({kBatchSize, 1})));
}
INSTANTIATE_TEST_SUITE_P(TfLiteSessionBatchSizeTests,
                         TfLiteSessionBatchSizeTest, ::testing::Bool());
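
// Installs a custom BasicBatchScheduler whose split function wraps
// SplitTfLiteInputTask, and verifies that a 500-element batch with
// max_execution_batch_size 130 triggers exactly one split.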
TEST(TfLiteSession, TestSetScheduler) {
  std::string model_bytes;
  TF_ASSERT_OK(ReadFileToString(Env::Default(),
                                test_util::TestSrcDirPath(kParseExampleModel),
                                &model_bytes));
  auto model = tflite::FlatBufferModel::BuildFromModel(
      flatbuffers::GetRoot<tflite::Model>(model_bytes.data()));
  auto model_signature_def_map = GetTestSignatureDefMap();
  ::google::protobuf::Map<string, SignatureDef> signatures;
  std::unique_ptr<TfLiteSession> sess;
  tensorflow::SessionOptions options;

  // Wrap the default split function so the test can verify it was invoked.
  int split_called = 0;
  auto TestSplitTfLiteInputTask =
      [&split_called](
          std::unique_ptr<TfLiteBatchTask>* input_task_ptr,
          int open_batch_remaining_slot, int max_batch_size,
          std::vector<std::unique_ptr<TfLiteBatchTask>>* output_tasks) {
        split_called++;
        auto status = TfLiteSession::SplitTfLiteInputTask(
            input_task_ptr, open_batch_remaining_slot, max_batch_size,
            output_tasks);
        return status;
      };

  BasicBatchScheduler<TfLiteBatchTask>::Options scheduler_options;
  scheduler_options.num_batch_threads = 1;
  scheduler_options.max_batch_size = internal::kInitialBatchSize;
  scheduler_options.enable_large_batch_splitting = true;
  scheduler_options.max_execution_batch_size = 130;
  scheduler_options.max_enqueued_batches = 4;
  scheduler_options.split_input_task_func = TestSplitTfLiteInputTask;

  TF_ASSERT_OK(TfLiteSession::Create(std::move(model_bytes), options, 1, 1,
                                     &sess, &signatures));

  TF_ASSERT_OK(sess->SetScheduler(
      TfLiteSession::CreateDefaultBasicBatchScheduler, scheduler_options));

  const int batch_size = 500;
  const float default_value = 0;
  std::vector<tstring> example_list;
  std::vector<float> expected;
  std::vector<tstring> expected_bytes;
  std::vector<Tensor> outputs;

  std::mt19937 random_engine;
  auto random_func = [&]() {
    return std::uniform_real_distribution<float>(-0.5, 0.5)(random_engine);
  };
  const std::string kTestString = "test string";
  const std::string kDefaultString = "missing";
  for (int i = 0; i < batch_size; i++) {
    float val = random_func();
    tensorflow::Example example;
    string str;
    if (i % 2 == 0) {
      expected.push_back(default_value);
      expected_bytes.push_back(kDefaultString);
    } else {
      expected.push_back(val);
      expected_bytes.push_back(kTestString);
      auto* features = example.mutable_features();
      (*features->mutable_feature())["x"].mutable_float_list()->add_value(val);
      (*features->mutable_feature())["y"].mutable_bytes_list()->add_value(
          kTestString);
    }
    example.SerializeToString(&str);
    example_list.push_back(str);
  }

  Tensor example_list_tensor =
      test::AsTensor<tstring>(example_list, TensorShape({batch_size}));
  TF_EXPECT_OK(sess->Run(
      {{"input", example_list_tensor}},
      {"ParseExample/ParseExampleV2", "ParseExample/ParseExampleV2:1"}, {},
      &outputs));
  test::ExpectTensorEqual<float>(
      outputs[0],
      test::AsTensor<float>(expected, TensorShape({batch_size, 1})));
  EXPECT_EQ(outputs.size(), 2);
  test::ExpectTensorEqual<tstring>(
      outputs[1],
      test::AsTensor<tstring>(expected_bytes, TensorShape({batch_size, 1})));
  EXPECT_EQ(split_called, 1);
}
#ifdef PLATFORM_GOOGLE
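
// Each benchmark below builds a shared TfLiteSession once (on thread 0),
// releases it into a static pointer, and never deletes it so that all
// benchmark threads can reuse the same session.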
static void BM_Reshape(benchmark::State& state, bool use_flex_op) {
  static TfLiteSession* session;
  if (state.thread_index() == 0) {
    auto model_signature_def_map = GetTestSignatureDefMap();
    string model_bytes = BuildTestModel(tflite::TensorType_INT32, use_flex_op,
                                        &model_signature_def_map);
    ::google::protobuf::Map<string, SignatureDef> signatures;
    std::unique_ptr<TfLiteSession> sess;
    tensorflow::SessionOptions options;
    TF_ASSERT_OK(TfLiteSession::Create(
        std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
        absl::GetFlag(FLAGS_num_tflite_interpreters), &sess, &signatures));
    session = sess.release();
  }
  Tensor input = test::AsTensor<int32>({1, 2, 3, 4, 5, 6}, TensorShape({6}));
  Tensor input_shape = test::AsTensor<int32>({3, 2}, TensorShape({2}));
  std::vector<Tensor> outputs;
  for (auto _ : state) {
    TF_ASSERT_OK(session->Run(
        {{kTestModelInputList, input}, {kTestModelInputShape, input_shape}},
        {kTestModelOutput}, {}, &outputs));
  }
}

static void BM_Reshape_Builtin(benchmark::State& state) {
  BM_Reshape(state, /*use_flex_op=*/false);
}
BENCHMARK(BM_Reshape_Builtin)->UseRealTime()->ThreadRange(1, 64);

static void BM_Reshape_Flex(benchmark::State& state) {
  BM_Reshape(state, /*use_flex_op=*/true);
}
BENCHMARK(BM_Reshape_Flex)->UseRealTime()->ThreadRange(1, 64);
void BM_HalfPlusTwo(benchmark::State& state) {
  static TfLiteSession* session;
  if (state.thread_index() == 0) {
    string model_bytes;
    TF_ASSERT_OK(ReadFileToString(
        Env::Default(), test_util::TestSrcDirPath(kTestModel), &model_bytes));
    ::google::protobuf::Map<string, SignatureDef> signatures;
    std::unique_ptr<TfLiteSession> sess;
    tensorflow::SessionOptions options;
    TF_ASSERT_OK(TfLiteSession::Create(
        std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
        absl::GetFlag(FLAGS_num_tflite_interpreters), &sess, &signatures));
    session = sess.release();
  }
  Tensor input = test::AsTensor<float>({1.0, 2.0, 3.0}, TensorShape({3}));
  std::vector<Tensor> outputs;
  for (auto _ : state) {
    TF_ASSERT_OK(session->Run({{"x", input}}, {"y"}, {}, &outputs));
  }
}
BENCHMARK(BM_HalfPlusTwo)->UseRealTime()->ThreadRange(1, 64);
void BM_MobileNet(benchmark::State& state) {
  static TfLiteSession* session;
  if (state.thread_index() == 0) {
    string model_bytes;
    TF_ASSERT_OK(ReadFileToString(Env::Default(),
                                  test_util::TestSrcDirPath(kMobileNetModel),
                                  &model_bytes));
    ::google::protobuf::Map<string, SignatureDef> signatures;
    std::unique_ptr<TfLiteSession> sess;
    tensorflow::SessionOptions options;
    TF_ASSERT_OK(TfLiteSession::Create(
        std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
        absl::GetFlag(FLAGS_num_tflite_interpreters), &sess, &signatures));
    session = sess.release();
  }
  std::vector<uint8> x_data(1 * 224 * 224 * 3, 1);
  Tensor x = test::AsTensor<uint8>(x_data, TensorShape({1, 224, 224, 3}));
  std::vector<Tensor> outputs;
  for (auto _ : state) {
    TF_ASSERT_OK(session->Run(
        {{"input", x}}, {"MobilenetV1/Predictions/Reshape_1"}, {}, &outputs));
  }
}
BENCHMARK(BM_MobileNet)->UseRealTime()->ThreadRange(1, 64);
void BM_ParseExample(benchmark::State& state) {
  static TfLiteSession* session;
  if (state.thread_index() == 0) {
    string model_bytes;
    TF_ASSERT_OK(ReadFileToString(Env::Default(),
                                  test_util::TestSrcDirPath(kParseExampleModel),
                                  &model_bytes));
    ::google::protobuf::Map<string, SignatureDef> signatures;
    std::unique_ptr<TfLiteSession> sess;
    tensorflow::SessionOptions options;
    TF_ASSERT_OK(TfLiteSession::Create(
        std::move(model_bytes), options, absl::GetFlag(FLAGS_num_pools),
        absl::GetFlag(FLAGS_num_tflite_interpreters), &sess, &signatures));
    session = sess.release();
  }
  const int kBatchSize = 500;
  std::vector<tstring> example_list;
  std::mt19937 random_engine;
  auto random_func = [&]() {
    return std::uniform_real_distribution<float>(-0.5, 0.5)(random_engine);
  };
  for (int i = 0; i < kBatchSize; i++) {
    float val = random_func();
    tensorflow::Example example;
    string str;
    auto* features = example.mutable_features();
    (*features->mutable_feature())["x"].mutable_float_list()->add_value(val);
    (*features->mutable_feature())["y"].mutable_bytes_list()->add_value("Test");
    example.SerializeToString(&str);
    example_list.push_back(str);
  }

  Tensor input_batch =
      test::AsTensor<tstring>(example_list, TensorShape({kBatchSize}));
  std::vector<Tensor> outputs;
  for (auto _ : state) {
    TF_ASSERT_OK(session->Run(
        {{"input", input_batch}},
        {"ParseExample/ParseExampleV2", "ParseExample/ParseExampleV2:1"}, {},
        &outputs));
  }
}
BENCHMARK(BM_ParseExample)->UseRealTime()->ThreadRange(1, 64);