15 #include "tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h"
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/core/example/example.pb.h"
25 #include "tensorflow/core/example/feature.pb.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/lite/kernels/parse_example/parse_example.h"
28 #include "tensorflow/lite/kernels/register.h"
29 #include "tensorflow_serving/test_util/test_util.h"
31 namespace tensorflow {
35 using tensorflow::gtl::ArraySlice;
// Path (under the test source tree) of a TFLite model that uses the custom
// ParseExample op.
// NOTE(review): the trailing filename lines were lost in extraction and are
// reconstructed as "model.tflite" — confirm against the testdata layout.
constexpr char kParseExampleModel[] =
    "/servables/tensorflow/testdata/parse_example_tflite/00000123/"
    "model.tflite";

// Path of a quantized MobileNet v1 TFLite model used by other tests in this
// file.
constexpr char kMobileNetModel[] =
    "/servables/tensorflow/testdata/mobilenet_v1_quant_tflite/00000123/"
    "model.tflite";
45 TEST(TfLiteInterpreterPool, CreateTfLiteInterpreterPoolTest) {
47 TF_ASSERT_OK(ReadFileToString(Env::Default(),
48 test_util::TestSrcDirPath(kParseExampleModel),
50 auto model = tflite::FlatBufferModel::BuildFromModel(
51 flatbuffers::GetRoot<tflite::Model>(model_bytes.data()));
53 const tensorflow::SessionOptions options;
54 std::unique_ptr<TfLiteInterpreterPool> interpreter_pool;
55 TF_ASSERT_OK(TfLiteInterpreterPool::CreateTfLiteInterpreterPool(
56 model.get(), options, pool_size, interpreter_pool));
58 auto interpreter = interpreter_pool->GetInterpreter();
59 interpreter_pool->ReturnInterpreter(std::move(interpreter));
60 interpreter_pool.reset();
63 TF_ASSERT_OK(TfLiteInterpreterPool::CreateTfLiteInterpreterPool(
64 model.get(), options, pool_size, interpreter_pool));
65 interpreter = interpreter_pool->GetInterpreter();
66 interpreter_pool->ReturnInterpreter(std::move(interpreter));
67 auto next_interpreter = interpreter_pool->GetInterpreter();
68 interpreter_pool->ReturnInterpreter(std::move(next_interpreter));
69 interpreter = interpreter_pool->GetInterpreter();
70 next_interpreter = interpreter_pool->GetInterpreter();
71 interpreter_pool->ReturnInterpreter(std::move(interpreter));
72 interpreter_pool->ReturnInterpreter(std::move(next_interpreter));
73 interpreter_pool.reset();
76 int GetTensorSize(
const TfLiteTensor* tflite_tensor) {
78 for (
int i = 0; i < tflite_tensor->dims->size; ++i) {
79 size *= tflite_tensor->dims->data[i];
85 std::vector<T> ExtractVector(
const TfLiteTensor* tflite_tensor) {
86 const T* v =
reinterpret_cast<T*
>(tflite_tensor->data.raw);
87 return std::vector<T>(v, v + GetTensorSize(tflite_tensor));
91 std::vector<std::string> ExtractVector(
const TfLiteTensor* tflite_tensor) {
92 std::vector<std::string> out;
93 for (
int i = 0; i < tflite::GetStringCount(tflite_tensor); ++i) {
94 auto ref = tflite::GetString(tflite_tensor, i);
95 out.emplace_back(ref.str, ref.len);
100 TEST(TfLiteInterpreterWrapper, TfLiteInterpreterWrapperTest) {
102 TF_ASSERT_OK(ReadFileToString(Env::Default(),
103 test_util::TestSrcDirPath(kParseExampleModel),
105 auto model = tflite::FlatBufferModel::BuildFromModel(
106 flatbuffers::GetRoot<tflite::Model>(model_bytes.data()));
107 tflite::ops::builtin::BuiltinOpResolver resolver;
108 tflite::ops::custom::AddParseExampleOp(&resolver);
109 std::unique_ptr<tflite::Interpreter> interpreter;
110 ASSERT_EQ(tflite::InterpreterBuilder(*model, resolver)(&interpreter,
113 ASSERT_EQ(interpreter->inputs().size(), 1);
114 const int idx = interpreter->inputs()[0];
115 auto* tensor = interpreter->tensor(idx);
116 ASSERT_EQ(tensor->type, kTfLiteString);
117 int fixed_batch_size = 10;
118 int actual_batch_size = 3;
119 interpreter->ResizeInputTensor(idx, {fixed_batch_size});
120 ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
122 auto interpreter_wrapper =
123 std::make_unique<TfLiteInterpreterWrapper>(std::move(interpreter));
124 interpreter_wrapper->SetBatchSize(fixed_batch_size);
125 ASSERT_EQ(interpreter_wrapper->GetBatchSize(), fixed_batch_size);
126 std::vector<const Tensor*> data;
127 std::vector<float> expected_floats;
128 std::vector<std::string> expected_strs;
129 std::vector<std::string> expected_input_strs;
131 shape.AddDim(actual_batch_size);
132 Tensor t(DT_STRING, shape);
133 for (
int i = 0; i < actual_batch_size; ++i) {
134 tensorflow::Example example;
136 auto* features = example.mutable_features();
137 const float f = i % 2 == 1 ? 1.0 : -1.0;
138 const std::string s = i % 2 == 1 ?
"test" :
"missing";
139 expected_floats.push_back(f);
140 expected_strs.push_back(s);
141 (*features->mutable_feature())[
"x"].mutable_float_list()->add_value(
142 expected_floats.back());
143 (*features->mutable_feature())[
"y"].mutable_bytes_list()->add_value(
144 expected_strs.back());
145 example.SerializeToString(&str);
146 t.flat<tstring>()(i) = str;
147 expected_input_strs.push_back(str);
150 ASSERT_FALSE(interpreter_wrapper->SetStringData(
151 data, tensor, -1, actual_batch_size) == absl::OkStatus());
153 interpreter_wrapper->SetStringData(data, tensor, idx, actual_batch_size));
154 auto wrapped = interpreter_wrapper->Get();
155 ASSERT_EQ(wrapped->inputs().size(), 1);
156 int input_idx = wrapped->inputs()[0];
157 auto tflite_input_tensor = wrapped->tensor(input_idx);
158 ASSERT_EQ(GetTensorSize(tflite_input_tensor), fixed_batch_size);
159 ASSERT_EQ(tflite::GetStringCount(tflite_input_tensor), actual_batch_size);
160 auto input_strs = ExtractVector<std::string>(tflite_input_tensor);
161 EXPECT_THAT(input_strs, ::testing::ElementsAreArray(expected_input_strs));
162 ASSERT_EQ(interpreter_wrapper->Invoke(), kTfLiteOk);
163 const std::vector<int>& indices = wrapped->outputs();
164 auto* tflite_tensor = wrapped->tensor(indices[0]);
165 ASSERT_EQ(tflite_tensor->type, kTfLiteFloat32);
166 ASSERT_EQ(GetTensorSize(tflite_tensor), fixed_batch_size);
168 absl::Span<const float>(ExtractVector<float>(tflite_tensor).data(),
170 ::testing::ElementsAreArray(expected_floats));
171 tflite_tensor = wrapped->tensor(indices[1]);
172 ASSERT_EQ(tflite_tensor->type, kTfLiteString);
173 ASSERT_EQ(GetTensorSize(tflite_tensor), fixed_batch_size);
175 absl::Span<const std::string>(
176 ExtractVector<std::string>(tflite_tensor).data(), actual_batch_size),
177 ::testing::ElementsAreArray(expected_strs));