TensorFlow Serving C++ API Documentation
multi_inference_helper_test.cc
/* Copyright 2018 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/multi_inference_helper.h"
17 
18 #include <memory>
19 #include <type_traits>
20 #include <utility>
21 #include <vector>
22 
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include "tensorflow/cc/saved_model/loader.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "tensorflow/core/example/example.pb.h"
28 #include "tensorflow/core/example/feature.pb.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow_serving/apis/classification.pb.h"
31 #include "tensorflow_serving/apis/input.pb.h"
32 #include "tensorflow_serving/apis/regression.pb.h"
33 #include "tensorflow_serving/core/availability_preserving_policy.h"
34 #include "tensorflow_serving/model_servers/model_platform_types.h"
35 #include "tensorflow_serving/model_servers/platform_config_util.h"
36 #include "tensorflow_serving/model_servers/server_core.h"
37 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
38 #include "tensorflow_serving/servables/tensorflow/util.h"
39 #include "tensorflow_serving/test_util/test_util.h"
40 
namespace tensorflow {
namespace serving {
namespace {

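// Name and version under which the half_plus_two test model is served.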
constexpr char kTestModelName[] = "test_model";
constexpr int kTestModelVersion = 123;

// Test fixture for MultiInferenceTest-related tests. It sets up a ServerCore
// pointing to the TF1 or TF2 version of the half_plus_two SavedModel,
// depending on `T`.
typedef std::integral_constant<int, 1> tf1_model_t;
typedef std::integral_constant<int, 2> tf2_model_t;

template <typename T>
class MultiInferenceTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() {
    SetSignatureMethodNameCheckFeature(UseTf1Model());
    TF_ASSERT_OK(CreateServerCore(&server_core_));
  }

  static void TearDownTestSuite() { server_core_.reset(); }

 protected:
  static Status CreateServerCore(std::unique_ptr<ServerCore>* server_core) {
    ModelServerConfig config;
    auto model_config = config.mutable_model_config_list()->add_config();
    model_config->set_name(kTestModelName);
    const auto& tf1_saved_model = test_util::TensorflowTestSrcDirPath(
        "cc/saved_model/testdata/half_plus_two");
    const auto& tf2_saved_model = test_util::TestSrcDirPath(
        "/servables/tensorflow/testdata/saved_model_half_plus_two_tf2_cpu");
    model_config->set_base_path(UseTf1Model() ? tf1_saved_model
                                              : tf2_saved_model);
    model_config->set_model_platform(kTensorFlowModelPlatform);

    // For ServerCore Options, we leave servable_state_monitor_creator
    // unspecified so the default servable_state_monitor_creator will be used.
    ServerCore::Options options;
    options.model_server_config = config;
    options.platform_config_map =
        CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
    // Reduce the number of initial load threads to num_load_threads to avoid
    // timing out in tests.
    options.num_initial_load_threads = options.num_load_threads;
    options.aspired_version_policy =
        std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
    return ServerCore::Create(std::move(options), server_core);
  }

  static bool UseTf1Model() { return std::is_same<T, tf1_model_t>::value; }

  ServerCore* GetServerCore() { return this->server_core_.get(); }

  const int64_t servable_version_ = kTestModelVersion;

 private:
  static std::unique_ptr<ServerCore> server_core_;
};

template <typename T>
std::unique_ptr<ServerCore> MultiInferenceTest<T>::server_core_;

TYPED_TEST_SUITE_P(MultiInferenceTest);

// Test Helpers

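// Appends a tf.Example with the given float features to the request's input;
// e.g. AddInput({{"x", 2}}, &request) adds one example with feature x = 2.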
void AddInput(const std::vector<std::pair<string, float>>& feature_kv,
              MultiInferenceRequest* request) {
  auto* example =
      request->mutable_input()->mutable_example_list()->add_examples();
  auto* features = example->mutable_features()->mutable_feature();
  for (const auto& feature : feature_kv) {
    (*features)[feature.first].mutable_float_list()->add_value(feature.second);
  }
}

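// Fills `task` with a ModelSpec for kTestModelName and the given signature and
// method names, pinning the model version when `version` is positive.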
void PopulateTask(const string& signature_name, const string& method_name,
                  int64_t version, InferenceTask* task) {
  ModelSpec model_spec;
  model_spec.set_name(kTestModelName);
  if (version > 0) {
    model_spec.mutable_version()->set_value(version);
  }
  model_spec.set_signature_name(signature_name);
  *task->mutable_model_spec() = model_spec;
  task->set_method_name(method_name);
}

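// Asserts that `status` carries `expected_code` and that its error message
// contains `message_substring`.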
void ExpectStatusError(const Status& status,
                       const absl::StatusCode expected_code,
                       const string& message_substring) {
  ASSERT_EQ(expected_code, status.code());
  EXPECT_THAT(status.message(), ::testing::HasSubstr(message_substring));
}

// Tests

TYPED_TEST_P(MultiInferenceTest, MissingInputTest) {
  MultiInferenceRequest request;
  PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());

  MultiInferenceResponse response;
  ExpectStatusError(
      RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
                                      thread::ThreadPoolOptions(), request,
                                      &response),
      absl::StatusCode::kInvalidArgument, "Input is empty");
}

TYPED_TEST_P(MultiInferenceTest, UndefinedSignatureTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("ThisSignatureDoesNotExist", kRegressMethodName, -1,
               request.add_tasks());

  MultiInferenceResponse response;
  ExpectStatusError(
      RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
                                      thread::ThreadPoolOptions(), request,
                                      &response),
      absl::StatusCode::kInvalidArgument, "signature not found");
}

// Two ModelSpecs, accessing different models.
TYPED_TEST_P(MultiInferenceTest, InconsistentModelSpecsInRequestTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  // Valid signature.
  PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());

  // Add an invalid task to the request.
  ModelSpec model_spec;
  model_spec.set_name("ModelDoesNotExist");
  model_spec.set_signature_name("regress_x_to_y");
  auto* task = request.add_tasks();
  *task->mutable_model_spec() = model_spec;
  task->set_method_name(kRegressMethodName);

  MultiInferenceResponse response;
  ExpectStatusError(
      RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
                                      thread::ThreadPoolOptions(), request,
                                      &response),
      absl::StatusCode::kInvalidArgument, "must access the same model name");
}

TYPED_TEST_P(MultiInferenceTest, EvaluateDuplicateSignaturesTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
  // Add the same task again (error).
  PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());

  MultiInferenceResponse response;
  ExpectStatusError(
      RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
                                      thread::ThreadPoolOptions(), request,
                                      &response),
      absl::StatusCode::kInvalidArgument,
      "Duplicate evaluation of signature: regress_x_to_y");
}

TYPED_TEST_P(MultiInferenceTest, UnsupportedSignatureTypeTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("serving_default", kPredictMethodName, -1, request.add_tasks());

  MultiInferenceResponse response;
  ExpectStatusError(
      RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
                                      thread::ThreadPoolOptions(), request,
                                      &response),
      absl::StatusCode::kUnimplemented, "Unsupported signature");
}

TYPED_TEST_P(MultiInferenceTest, ValidSingleSignatureTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, this->servable_version_,
               request.add_tasks());

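  // regress_x_to_y is y = 0.5x + 2, so x = 2 regresses to 3.0.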
  MultiInferenceResponse expected_response;
  auto* inference_result = expected_response.add_results();
  auto* model_spec = inference_result->mutable_model_spec();
  *model_spec = request.tasks(0).model_spec();
  model_spec->mutable_version()->set_value(this->servable_version_);
  auto* regression_result = inference_result->mutable_regression_result();
  regression_result->add_regressions()->set_value(3.0);

  MultiInferenceResponse response;
  TF_ASSERT_OK(RunMultiInferenceWithServerCore(
      RunOptions(), this->GetServerCore(), thread::ThreadPoolOptions(), request,
      &response));
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}

TYPED_TEST_P(MultiInferenceTest, MultipleValidRegressSignaturesTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, this->servable_version_,
               request.add_tasks());
  PopulateTask("regress_x_to_y2", kRegressMethodName, this->servable_version_,
               request.add_tasks());

  MultiInferenceResponse expected_response;

  // regress_x_to_y is y = 0.5x + 2.
  auto* inference_result_1 = expected_response.add_results();
  auto* model_spec_1 = inference_result_1->mutable_model_spec();
  *model_spec_1 = request.tasks(0).model_spec();
  model_spec_1->mutable_version()->set_value(this->servable_version_);
  auto* regression_result_1 = inference_result_1->mutable_regression_result();
  regression_result_1->add_regressions()->set_value(3.0);

  // regress_x_to_y2 is y2 = 0.5x + 3.
  auto* inference_result_2 = expected_response.add_results();
  auto* model_spec_2 = inference_result_2->mutable_model_spec();
  *model_spec_2 = request.tasks(1).model_spec();
  model_spec_2->mutable_version()->set_value(this->servable_version_);
  auto* regression_result_2 = inference_result_2->mutable_regression_result();
  regression_result_2->add_regressions()->set_value(4.0);

  MultiInferenceResponse response;
  TF_ASSERT_OK(RunMultiInferenceWithServerCore(
      RunOptions(), this->GetServerCore(), thread::ThreadPoolOptions(), request,
      &response));
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}

TYPED_TEST_P(MultiInferenceTest, RegressAndClassifySignaturesTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, this->servable_version_,
               request.add_tasks());
  PopulateTask("classify_x_to_y", kClassifyMethodName, this->servable_version_,
               request.add_tasks());

  MultiInferenceResponse expected_response;
  auto* inference_result_1 = expected_response.add_results();
  auto* model_spec_1 = inference_result_1->mutable_model_spec();
  *model_spec_1 = request.tasks(0).model_spec();
  model_spec_1->mutable_version()->set_value(this->servable_version_);
  auto* regression_result = inference_result_1->mutable_regression_result();
  regression_result->add_regressions()->set_value(3.0);

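  // classify_x_to_y applies the same y = 0.5x + 2 map to the class score, so
  // x = 2 yields a single class with score 3.0.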
  auto* inference_result_2 = expected_response.add_results();
  auto* model_spec_2 = inference_result_2->mutable_model_spec();
  *model_spec_2 = request.tasks(1).model_spec();
  model_spec_2->mutable_version()->set_value(this->servable_version_);
  auto* classification_result =
      inference_result_2->mutable_classification_result();
  classification_result->add_classifications()->add_classes()->set_score(3.0);

  MultiInferenceResponse response;
  TF_ASSERT_OK(RunMultiInferenceWithServerCore(
      RunOptions(), this->GetServerCore(), thread::ThreadPoolOptions(), request,
      &response));
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}

// Verifies that RunMultiInferenceWithServerCoreWithModelSpec() uses the model
// spec override rather than the one in the request.
TYPED_TEST_P(MultiInferenceTest, ModelSpecOverride) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, this->servable_version_,
               request.add_tasks());
  auto model_spec_override =
      test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");

  MultiInferenceResponse response;
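  // The request's own model spec is valid, so the un-overridden call must not
  // fail with NOT_FOUND; with the override pointing at a nonexistent model,
  // it must.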
  EXPECT_NE(absl::StatusCode::kNotFound,
            RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
                                            thread::ThreadPoolOptions(),
                                            request, &response)
                .code());
  EXPECT_EQ(
      absl::StatusCode::kNotFound,
      RunMultiInferenceWithServerCoreWithModelSpec(
          RunOptions(), this->GetServerCore(), thread::ThreadPoolOptions(),
          model_spec_override, request, &response)
          .code());
}

TYPED_TEST_P(MultiInferenceTest, ThreadPoolOptions) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, this->servable_version_,
               request.add_tasks());

  MultiInferenceResponse expected_response;
  auto* inference_result = expected_response.add_results();
  auto* model_spec = inference_result->mutable_model_spec();
  *model_spec = request.tasks(0).model_spec();
  model_spec->mutable_version()->set_value(this->servable_version_);
  auto* regression_result = inference_result->mutable_regression_result();
  regression_result->add_regressions()->set_value(3.0);

  test_util::CountingThreadPool inter_op_threadpool(Env::Default(), "InterOp",
                                                    /*num_threads=*/1);
  test_util::CountingThreadPool intra_op_threadpool(Env::Default(), "IntraOp",
                                                    /*num_threads=*/1);
  thread::ThreadPoolOptions thread_pool_options;
  thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
  thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
  MultiInferenceResponse response;
  TF_ASSERT_OK(
      RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
                                      thread_pool_options, request, &response));
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));

  // The intra_op_threadpool doesn't have anything scheduled, so only verify
  // that the inter_op_threadpool was used.
  ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
}

REGISTER_TYPED_TEST_SUITE_P(
    MultiInferenceTest, MissingInputTest, UndefinedSignatureTest,
    InconsistentModelSpecsInRequestTest, EvaluateDuplicateSignaturesTest,
    UnsupportedSignatureTypeTest, ValidSingleSignatureTest,
    MultipleValidRegressSignaturesTest, RegressAndClassifySignaturesTest,
    ModelSpecOverride, ThreadPoolOptions);

typedef ::testing::Types<tf1_model_t, tf2_model_t> ModelTypes;
INSTANTIATE_TYPED_TEST_SUITE_P(MultiInference, MultiInferenceTest, ModelTypes);

}  // namespace
}  // namespace serving
}  // namespace tensorflow