TensorFlow Serving C++ API Documentation
predict_impl_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/predict_impl.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/cc/saved_model/signature_constants.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow_serving/core/availability_preserving_policy.h"
26 #include "tensorflow_serving/model_servers/model_platform_types.h"
27 #include "tensorflow_serving/model_servers/platform_config_util.h"
28 #include "tensorflow_serving/model_servers/server_core.h"
29 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
30 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
31 #include "tensorflow_serving/servables/tensorflow/test_util/fake_thread_pool_factory.h"
32 #include "tensorflow_serving/servables/tensorflow/test_util/fake_thread_pool_factory.pb.h"
33 #include "tensorflow_serving/test_util/test_util.h"
34 #include "tensorflow_serving/util/oss_or_google.h"
35 
36 namespace tensorflow {
37 namespace serving {
38 namespace {
39 
40 constexpr char kTestModelName[] = "test_model";
41 constexpr int kTestModelVersion = 123;
42 
43 const char kInputTensorKey[] = "x";
44 const char kOutputTensorKey[] = "y";
45 
46 class PredictImplTest : public ::testing::Test {
47  public:
48  static void SetUpTestSuite() {
49  TF_ASSERT_OK(CreateServerCore(test_util::TensorflowTestSrcDirPath(
50  "cc/saved_model/testdata/half_plus_two"),
51  &saved_model_server_core_));
52  TF_ASSERT_OK(CreateServerCore(
53  test_util::TestSrcDirPath(
54  "/servables/tensorflow/testdata/saved_model_counter"),
55  &saved_model_server_core_counter_model_));
56  }
57 
58  static void TearDownTestSuite() {
59  server_core_.reset();
60  server_core_bad_model_.reset();
61  saved_model_server_core_.reset();
62  saved_model_server_core_counter_model_.reset();
63  }
64 
65  protected:
66  static Status CreateServerCore(const string& model_path,
67  std::unique_ptr<ServerCore>* server_core) {
68  ModelServerConfig config;
69  auto model_config = config.mutable_model_config_list()->add_config();
70  model_config->set_name(kTestModelName);
71  model_config->set_base_path(model_path);
72  model_config->set_model_platform(kTensorFlowModelPlatform);
73 
74  // For ServerCore Options, we leave servable_state_monitor_creator
75  // unspecified so the default servable_state_monitor_creator will be used.
76  ServerCore::Options options;
77  options.model_server_config = config;
78  options.platform_config_map =
79  CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
80  options.aspired_version_policy =
81  std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
82  // Reduce the number of initial load threads to be num_load_threads to avoid
83  // timing out in tests.
84  options.num_initial_load_threads = options.num_load_threads;
85  return ServerCore::Create(std::move(options), server_core);
86  }
87 
88  ServerCore* GetServerCore() { return saved_model_server_core_.get(); }
89 
90  ServerCore* GetServerCoreWithCounterModel() {
91  return saved_model_server_core_counter_model_.get();
92  }
93 
94  RunOptions GetRunOptions() { return RunOptions(); }
95 
96  private:
97  static std::unique_ptr<ServerCore> server_core_;
98  static std::unique_ptr<ServerCore> server_core_bad_model_;
99  static std::unique_ptr<ServerCore> saved_model_server_core_;
100  static std::unique_ptr<ServerCore> saved_model_server_core_counter_model_;
101 };
102 
103 std::unique_ptr<ServerCore> PredictImplTest::server_core_;
104 std::unique_ptr<ServerCore> PredictImplTest::server_core_bad_model_;
105 std::unique_ptr<ServerCore> PredictImplTest::saved_model_server_core_;
106 std::unique_ptr<ServerCore>
107  PredictImplTest::saved_model_server_core_counter_model_;
108 
109 TEST_F(PredictImplTest, MissingOrEmptyModelSpec) {
110  PredictRequest request;
111  PredictResponse response;
112 
113  // Empty request is invalid.
114  TensorflowPredictor predictor;
115  EXPECT_EQ(
116  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
117  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
118  .code());
119 
120  ModelSpec* model_spec = request.mutable_model_spec();
121  model_spec->clear_name();
122 
123  // Model name is not specified.
124  EXPECT_EQ(
125  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
126  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
127  .code());
128 
129  // Model name is wrong, not found.
130  model_spec->set_name("test");
131  EXPECT_EQ(
132  tensorflow::error::NOT_FOUND,
133  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
134  .code());
135 }
136 
137 TEST_F(PredictImplTest, EmptyInputList) {
138  PredictRequest request;
139  PredictResponse response;
140 
141  ModelSpec* model_spec = request.mutable_model_spec();
142  model_spec->set_name(kTestModelName);
143  model_spec->mutable_version()->set_value(kTestModelVersion);
144 
145  TensorflowPredictor predictor;
146  // The input is empty.
147  EXPECT_EQ(
148  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
149  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
150  .code());
151 }
152 
153 TEST_F(PredictImplTest, InputTensorsDontMatchModelSpecInputs) {
154  PredictRequest request;
155  PredictResponse response;
156 
157  ModelSpec* model_spec = request.mutable_model_spec();
158  model_spec->set_name(kTestModelName);
159  model_spec->mutable_version()->set_value(kTestModelVersion);
160 
161  TensorProto tensor_proto;
162  tensor_proto.add_string_val("any_key");
163  tensor_proto.set_dtype(tensorflow::DT_STRING);
164  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
165 
166  TensorflowPredictor predictor;
167  auto inputs = request.mutable_inputs();
168  (*inputs)["key"] = tensor_proto;
169  EXPECT_EQ(
170  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
171  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
172  .code());
173 }
174 
175 TEST_F(PredictImplTest, OutputFiltersDontMatchModelSpecOutputs) {
176  PredictRequest request;
177  PredictResponse response;
178 
179  ModelSpec* model_spec = request.mutable_model_spec();
180  model_spec->set_name(kTestModelName);
181  model_spec->mutable_version()->set_value(kTestModelVersion);
182 
183  TensorProto tensor_proto;
184  tensor_proto.add_float_val(2.0);
185  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
186  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
187  request.add_output_filter("output_filter");
188 
189  TensorflowPredictor predictor;
190  // Output filter like this doesn't exist.
191  EXPECT_EQ(
192  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
193  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
194  .code());
195 
196  request.clear_output_filter();
197  request.add_output_filter(kOutputTensorKey);
198  TF_EXPECT_OK(
199  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
200  request.add_output_filter(kOutputTensorKey);
201 
202  // Duplicate output filter specified.
203  EXPECT_EQ(
204  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
205  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
206  .code());
207 }
208 
209 TEST_F(PredictImplTest, InputTensorsHaveWrongType) {
210  PredictRequest request;
211  PredictResponse response;
212 
213  ModelSpec* model_spec = request.mutable_model_spec();
214  model_spec->set_name(kTestModelName);
215  model_spec->mutable_version()->set_value(kTestModelVersion);
216 
217  TensorProto tensor_proto;
218  tensor_proto.add_string_val("any_key");
219  tensor_proto.set_dtype(tensorflow::DT_STRING);
220  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
221  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
222  request.add_output_filter(kOutputTensorKey);
223 
224  TensorflowPredictor predictor;
225  // Input tensors are all wrong.
226  EXPECT_EQ(
227  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
228  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
229  .code());
230 }
231 
232 TEST_F(PredictImplTest, PredictionSuccess) {
233  PredictRequest request;
234  PredictResponse response;
235 
236  ModelSpec* model_spec = request.mutable_model_spec();
237  model_spec->set_name(kTestModelName);
238  model_spec->mutable_version()->set_value(kTestModelVersion);
239 
240  TensorProto tensor_proto;
241  tensor_proto.add_float_val(2.0);
242  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
243  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
244 
245  TensorflowPredictor predictor;
246  TF_EXPECT_OK(
247  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
248  TensorProto output_tensor_proto;
249  output_tensor_proto.add_float_val(3);
250  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
251  output_tensor_proto.mutable_tensor_shape();
252  PredictResponse expected_response;
253  *expected_response.mutable_model_spec() = *model_spec;
254  expected_response.mutable_model_spec()->set_signature_name(
255  kDefaultServingSignatureDefKey);
256  (*expected_response.mutable_outputs())[kOutputTensorKey] =
257  output_tensor_proto;
258  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
259 }
260 
261 // Test querying a model with a named regression signature (not default).
262 TEST_F(PredictImplTest, PredictionWithNamedRegressionSignature) {
263  PredictRequest request;
264  PredictResponse response;
265 
266  ModelSpec* model_spec = request.mutable_model_spec();
267  model_spec->set_name(kTestModelName);
268  model_spec->mutable_version()->set_value(kTestModelVersion);
269  model_spec->set_signature_name("regress_x2_to_y3");
270 
271  TensorProto tensor_proto;
272  tensor_proto.add_float_val(2.0);
273  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
274  (*request.mutable_inputs())[kRegressInputs] = tensor_proto;
275  TensorflowPredictor predictor;
276  TF_ASSERT_OK(
277  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
278  TensorProto output_tensor_proto;
279  output_tensor_proto.add_float_val(4);
280  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
281  output_tensor_proto.mutable_tensor_shape();
282  PredictResponse expected_response;
283  *expected_response.mutable_model_spec() = *model_spec;
284  (*expected_response.mutable_outputs())[kRegressOutputs] = output_tensor_proto;
285  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
286 }
287 
288 // Test querying a model with a classification signature. Predict calls work
289 // with predict, classify, and regress signatures when using SavedModel.
290 TEST_F(PredictImplTest, PredictionWithNamedClassificationSignature) {
291  PredictRequest request;
292  PredictResponse response;
293 
294  ModelSpec* model_spec = request.mutable_model_spec();
295  model_spec->set_name(kTestModelName);
296  model_spec->mutable_version()->set_value(kTestModelVersion);
297  model_spec->set_signature_name("classify_x2_to_y3");
298 
299  TensorProto tensor_proto;
300  tensor_proto.add_float_val(2.0);
301  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
302  (*request.mutable_inputs())[kClassifyInputs] = tensor_proto;
303 
304  TensorflowPredictor predictor;
305  TF_ASSERT_OK(
306  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
307  TensorProto output_tensor_proto;
308  output_tensor_proto.add_float_val(4);
309  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
310  output_tensor_proto.mutable_tensor_shape();
311  PredictResponse expected_response;
312  *expected_response.mutable_model_spec() = *model_spec;
313  (*expected_response.mutable_outputs())[kClassifyOutputScores] =
314  output_tensor_proto;
315  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
316 }
317 
318 // Test querying a counter model with signatures. Predict calls work with
319 // customized signatures. It calls get_counter, incr_counter,
320 // reset_counter, incr_counter, and incr_counter_by(3) in order.
321 //
322 // *Notes*: These signatures are stateful and over-simplied only to demonstrate
323 // Predict calls with only inputs or outputs. State is not supported in
324 // TensorFlow Serving on most scalable or production hosting environments.
325 TEST_F(PredictImplTest, PredictionWithCustomizedSignatures) {
326  PredictRequest request;
327  PredictResponse response;
328  TensorflowPredictor predictor;
329 
330  // Call get_counter. Expected result 0.
331  ModelSpec* model_spec = request.mutable_model_spec();
332  model_spec->set_name(kTestModelName);
333  model_spec->mutable_version()->set_value(kTestModelVersion);
334  model_spec->set_signature_name("get_counter");
335 
336  TF_ASSERT_OK(predictor.Predict(
337  GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));
338 
339  PredictResponse expected_get_counter;
340  *expected_get_counter.mutable_model_spec() = *model_spec;
341  TensorProto output_get_counter;
342  output_get_counter.add_float_val(0);
343  output_get_counter.set_dtype(tensorflow::DT_FLOAT);
344  output_get_counter.mutable_tensor_shape();
345  (*expected_get_counter.mutable_outputs())["output"] = output_get_counter;
346  EXPECT_THAT(response, test_util::EqualsProto(expected_get_counter));
347 
348  // Call incr_counter. Expect: 1.
349  model_spec->set_signature_name("incr_counter");
350  TF_ASSERT_OK(predictor.Predict(
351  GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));
352 
353  PredictResponse expected_incr_counter;
354  *expected_incr_counter.mutable_model_spec() = *model_spec;
355  TensorProto output_incr_counter;
356  output_incr_counter.add_float_val(1);
357  output_incr_counter.set_dtype(tensorflow::DT_FLOAT);
358  output_incr_counter.mutable_tensor_shape();
359  (*expected_incr_counter.mutable_outputs())["output"] = output_incr_counter;
360  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter));
361 
362  // Call reset_counter. Expect: 0.
363  model_spec->set_signature_name("reset_counter");
364  TF_ASSERT_OK(predictor.Predict(
365  GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));
366 
367  PredictResponse expected_reset_counter;
368  *expected_reset_counter.mutable_model_spec() = *model_spec;
369  TensorProto output_reset_counter;
370  output_reset_counter.add_float_val(0);
371  output_reset_counter.set_dtype(tensorflow::DT_FLOAT);
372  output_reset_counter.mutable_tensor_shape();
373  (*expected_reset_counter.mutable_outputs())["output"] = output_reset_counter;
374  EXPECT_THAT(response, test_util::EqualsProto(expected_reset_counter));
375 
376  // Call incr_counter. Expect: 1.
377  model_spec->set_signature_name("incr_counter");
378  request.add_output_filter("output");
379  TF_ASSERT_OK(predictor.Predict(
380  GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));
381  request.clear_output_filter();
382 
383  PredictResponse expected_incr_counter2;
384  *expected_incr_counter2.mutable_model_spec() = *model_spec;
385  TensorProto output_incr_counter2;
386  output_incr_counter2.add_float_val(1);
387  output_incr_counter2.set_dtype(tensorflow::DT_FLOAT);
388  output_incr_counter2.mutable_tensor_shape();
389  (*expected_incr_counter2.mutable_outputs())["output"] = output_incr_counter2;
390  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter2));
391 
392  // Call incr_counter_by. Expect: 4.
393  model_spec->set_signature_name("incr_counter_by");
394  TensorProto tensor_proto;
395  tensor_proto.add_float_val(3);
396  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
397  (*request.mutable_inputs())["delta"] = tensor_proto;
398 
399  TF_ASSERT_OK(predictor.Predict(
400  GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));
401 
402  PredictResponse expected_incr_counter_by;
403  *expected_incr_counter_by.mutable_model_spec() = *model_spec;
404  TensorProto output_incr_counter_by;
405  output_incr_counter_by.add_float_val(4);
406  output_incr_counter_by.set_dtype(tensorflow::DT_FLOAT);
407  output_incr_counter_by.mutable_tensor_shape();
408  (*expected_incr_counter_by.mutable_outputs())["output"] =
409  output_incr_counter_by;
410  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter_by));
411 }
412 
413 // Verifies that PredictWithModelSpec() uses the model spec override rather than
414 // the one in the request.
415 TEST_F(PredictImplTest, ModelSpecOverride) {
416  auto request = test_util::CreateProto<PredictRequest>(
417  "model_spec {"
418  " name: \"test_model\""
419  "}");
420  auto model_spec_override =
421  test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");
422 
423  TensorflowPredictor predictor;
424  PredictResponse response;
425  EXPECT_NE(tensorflow::error::NOT_FOUND,
426  predictor.Predict(RunOptions(), GetServerCore(), request, &response)
427  .code());
428  EXPECT_EQ(tensorflow::error::NOT_FOUND,
429  predictor
430  .PredictWithModelSpec(RunOptions(), GetServerCore(),
431  model_spec_override, request, &response)
432  .code());
433 }
434 
435 TEST_F(PredictImplTest, ThreadPoolFactory) {
436  PredictRequest request;
437  PredictResponse response;
438 
439  ModelSpec* model_spec = request.mutable_model_spec();
440  model_spec->set_name(kTestModelName);
441  model_spec->mutable_version()->set_value(kTestModelVersion);
442 
443  TensorProto tensor_proto;
444  tensor_proto.add_float_val(2.0);
445  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
446  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
447 
448  auto inter_op_threadpool =
449  std::make_shared<test_util::CountingThreadPool>(Env::Default(), "InterOp",
450  /*num_threads=*/1);
451  auto intra_op_threadpool =
452  std::make_shared<test_util::CountingThreadPool>(Env::Default(), "IntraOp",
453  /*num_threads=*/1);
454  test_util::FakeThreadPoolFactoryConfig fake_thread_pool_factory_config;
455  test_util::FakeThreadPoolFactory fake_thread_pool_factory(
456  fake_thread_pool_factory_config);
457  fake_thread_pool_factory.SetInterOpThreadPool(inter_op_threadpool);
458  fake_thread_pool_factory.SetIntraOpThreadPool(intra_op_threadpool);
459 
460  TensorflowPredictor predictor(&fake_thread_pool_factory);
461  TF_EXPECT_OK(
462  predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
463  TensorProto output_tensor_proto;
464  output_tensor_proto.add_float_val(3);
465  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
466  output_tensor_proto.mutable_tensor_shape();
467  PredictResponse expected_response;
468  *expected_response.mutable_model_spec() = *model_spec;
469  expected_response.mutable_model_spec()->set_signature_name(
470  kDefaultServingSignatureDefKey);
471  (*expected_response.mutable_outputs())[kOutputTensorKey] =
472  output_tensor_proto;
473  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
474 
475  // The intra_op_threadpool doesn't have anything scheduled.
476  ASSERT_GE(inter_op_threadpool->NumScheduled(), 1);
477 }
478 
479 } // namespace
480 } // namespace serving
481 } // namespace tensorflow
// Cross-reference (from documentation tooling): static Status
// ServerCore::Create(Options options, std::unique_ptr<ServerCore>* core),
// defined in server_core.cc, line 231.