TensorFlow Serving C++ API Documentation
tflite_interpreter_pool_test.cc
/* Copyright 2021 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/span.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow_serving/test_util/test_util.h"
namespace tensorflow {
namespace serving {
namespace internal {

using tensorflow::gtl::ArraySlice;

// Paths of the TF Lite test models, relative to the test source directory.
constexpr char kParseExampleModel[] =
    "/servables/tensorflow/testdata/parse_example_tflite/00000123/"
    "model.tflite";

constexpr char kMobileNetModel[] =
    "/servables/tensorflow/testdata/mobilenet_v1_quant_tflite/00000123/"
    "model.tflite";

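// Verifies that interpreters can be borrowed from and returned to a
// TfLiteInterpreterPool, first with a pool of size one and then of size two.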
TEST(TfLiteInterpreterPool, CreateTfLiteInterpreterPoolTest) {
  string model_bytes;
  TF_ASSERT_OK(ReadFileToString(Env::Default(),
                                test_util::TestSrcDirPath(kParseExampleModel),
                                &model_bytes));
  auto model = tflite::FlatBufferModel::BuildFromModel(
      flatbuffers::GetRoot<tflite::Model>(model_bytes.data()));
  int pool_size = 1;
  const tensorflow::SessionOptions options;
  std::unique_ptr<TfLiteInterpreterPool> interpreter_pool;
  TF_ASSERT_OK(TfLiteInterpreterPool::CreateTfLiteInterpreterPool(
      model.get(), options, pool_size, interpreter_pool));

  // A pool of size one hands out a single interpreter at a time.
  auto interpreter = interpreter_pool->GetInterpreter();
  interpreter_pool->ReturnInterpreter(std::move(interpreter));
  interpreter_pool.reset();

  // A pool of size two can hand out two interpreters concurrently.
  pool_size = 2;
  TF_ASSERT_OK(TfLiteInterpreterPool::CreateTfLiteInterpreterPool(
      model.get(), options, pool_size, interpreter_pool));
  interpreter = interpreter_pool->GetInterpreter();
  interpreter_pool->ReturnInterpreter(std::move(interpreter));
  auto next_interpreter = interpreter_pool->GetInterpreter();
  interpreter_pool->ReturnInterpreter(std::move(next_interpreter));
  interpreter = interpreter_pool->GetInterpreter();
  next_interpreter = interpreter_pool->GetInterpreter();
  interpreter_pool->ReturnInterpreter(std::move(interpreter));
  interpreter_pool->ReturnInterpreter(std::move(next_interpreter));
  interpreter_pool.reset();
}

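// Returns the number of elements in the tensor, i.e. the product of all of
// its dimension sizes.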
int GetTensorSize(const TfLiteTensor* tflite_tensor) {
  int size = 1;
  for (int i = 0; i < tflite_tensor->dims->size; ++i) {
    size *= tflite_tensor->dims->data[i];
  }
  return size;
}

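// Copies the contents of a numeric tensor into a std::vector<T>.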
template <typename T>
std::vector<T> ExtractVector(const TfLiteTensor* tflite_tensor) {
  const T* v = reinterpret_cast<T*>(tflite_tensor->data.raw);
  return std::vector<T>(v, v + GetTensorSize(tflite_tensor));
}

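// Specialization for string tensors, whose elements are stored in TF Lite's
// packed string format and must be read via tflite::GetString.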
template <>
std::vector<std::string> ExtractVector(const TfLiteTensor* tflite_tensor) {
  std::vector<std::string> out;
  for (int i = 0; i < tflite::GetStringCount(tflite_tensor); ++i) {
    auto ref = tflite::GetString(tflite_tensor, i);
    out.emplace_back(ref.str, ref.len);
  }
  return out;
}

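// Exercises TfLiteInterpreterWrapper with a fixed batch size larger than the
// actual number of inputs: strings are set on the input tensor, the model is
// invoked, and only the first actual_batch_size output entries are checked.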
TEST(TfLiteInterpreterWrapper, TfLiteInterpreterWrapperTest) {
  string model_bytes;
  TF_ASSERT_OK(ReadFileToString(Env::Default(),
                                test_util::TestSrcDirPath(kParseExampleModel),
                                &model_bytes));
  auto model = tflite::FlatBufferModel::BuildFromModel(
      flatbuffers::GetRoot<tflite::Model>(model_bytes.data()));
  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::ops::custom::AddParseExampleOp(&resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
  ASSERT_EQ(tflite::InterpreterBuilder(*model, resolver)(&interpreter,
                                                         /*num_threads=*/1),
            kTfLiteOk);
  ASSERT_EQ(interpreter->inputs().size(), 1);
  const int idx = interpreter->inputs()[0];
  auto* tensor = interpreter->tensor(idx);
  ASSERT_EQ(tensor->type, kTfLiteString);
  int fixed_batch_size = 10;
  int actual_batch_size = 3;
  interpreter->ResizeInputTensor(idx, {fixed_batch_size});
  ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);

  auto interpreter_wrapper =
      std::make_unique<TfLiteInterpreterWrapper>(std::move(interpreter));
  interpreter_wrapper->SetBatchSize(fixed_batch_size);
  ASSERT_EQ(interpreter_wrapper->GetBatchSize(), fixed_batch_size);

  // Build a batch of three serialized tf.Examples; odd entries get
  // x=1.0, y="test" and even entries get x=-1.0, y="missing".
  std::vector<const Tensor*> data;
  std::vector<float> expected_floats;
  std::vector<std::string> expected_strs;
  std::vector<std::string> expected_input_strs;
  TensorShape shape;
  shape.AddDim(actual_batch_size);
  Tensor t(DT_STRING, shape);
  for (int i = 0; i < actual_batch_size; ++i) {
    tensorflow::Example example;
    std::string str;
    auto* features = example.mutable_features();
    const float f = i % 2 == 1 ? 1.0 : -1.0;
    const std::string s = i % 2 == 1 ? "test" : "missing";
    expected_floats.push_back(f);
    expected_strs.push_back(s);
    (*features->mutable_feature())["x"].mutable_float_list()->add_value(
        expected_floats.back());
    (*features->mutable_feature())["y"].mutable_bytes_list()->add_value(
        expected_strs.back());
    example.SerializeToString(&str);
    t.flat<tstring>()(i) = str;
    expected_input_strs.push_back(str);
  }
  data.push_back(&t);
  // An invalid tensor index should be rejected.
  ASSERT_FALSE(interpreter_wrapper->SetStringData(
                   data, tensor, -1, actual_batch_size) == absl::OkStatus());
  TF_ASSERT_OK(
      interpreter_wrapper->SetStringData(data, tensor, idx, actual_batch_size));
  auto wrapped = interpreter_wrapper->Get();
  ASSERT_EQ(wrapped->inputs().size(), 1);
  int input_idx = wrapped->inputs()[0];
  auto tflite_input_tensor = wrapped->tensor(input_idx);
  // The input tensor keeps the fixed batch size but holds only
  // actual_batch_size strings.
  ASSERT_EQ(GetTensorSize(tflite_input_tensor), fixed_batch_size);
  ASSERT_EQ(tflite::GetStringCount(tflite_input_tensor), actual_batch_size);
  auto input_strs = ExtractVector<std::string>(tflite_input_tensor);
  EXPECT_THAT(input_strs, ::testing::ElementsAreArray(expected_input_strs));
  ASSERT_EQ(interpreter_wrapper->Invoke(), kTfLiteOk);
  // Outputs are padded to the fixed batch size; only the first
  // actual_batch_size entries are compared against the expected values.
  const std::vector<int>& indices = wrapped->outputs();
  auto* tflite_tensor = wrapped->tensor(indices[0]);
  ASSERT_EQ(tflite_tensor->type, kTfLiteFloat32);
  ASSERT_EQ(GetTensorSize(tflite_tensor), fixed_batch_size);
  EXPECT_THAT(
      absl::Span<const float>(ExtractVector<float>(tflite_tensor).data(),
                              actual_batch_size),
      ::testing::ElementsAreArray(expected_floats));
  tflite_tensor = wrapped->tensor(indices[1]);
  ASSERT_EQ(tflite_tensor->type, kTfLiteString);
  ASSERT_EQ(GetTensorSize(tflite_tensor), fixed_batch_size);
  EXPECT_THAT(
      absl::Span<const std::string>(
          ExtractVector<std::string>(tflite_tensor).data(), actual_batch_size),
      ::testing::ElementsAreArray(expected_strs));
}

}  // namespace internal
}  // namespace serving
}  // namespace tensorflow