// TensorFlow Serving C++ API documentation excerpt.
// Source file: predict_util_test.cc
1 /* Copyright 2018 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/predict_util.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/cc/saved_model/loader.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/threadpool_options.h"
29 #include "tensorflow_serving/core/availability_preserving_policy.h"
30 #include "tensorflow_serving/model_servers/model_platform_types.h"
31 #include "tensorflow_serving/model_servers/platform_config_util.h"
32 #include "tensorflow_serving/model_servers/server_core.h"
33 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
34 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
35 #include "tensorflow_serving/servables/tensorflow/util.h"
36 #include "tensorflow_serving/test_util/test_util.h"
37 #include "tensorflow_serving/util/oss_or_google.h"
38 
39 namespace tensorflow {
40 namespace serving {
41 namespace {
42 
// Name and version under which every test model is registered with ServerCore.
constexpr char kTestModelName[] = "test_model";
constexpr int kTestModelVersion = 123;

// Input/output tensor aliases of the half_plus_two model's default signature.
const char kInputTensorKey[] = "x";
const char kOutputTensorKey[] = "y";
48 
49 // Fake Session, that copies input tensors to output.
50 class FakeSession : public tensorflow::Session {
51  public:
52  FakeSession() {}
53  ~FakeSession() override = default;
54  Status Create(const GraphDef& graph) override {
55  return errors::Unimplemented("not available in fake");
56  }
57  Status Extend(const GraphDef& graph) override {
58  return errors::Unimplemented("not available in fake");
59  }
60  Status Close() override {
61  return errors::Unimplemented("not available in fake");
62  }
63  Status ListDevices(std::vector<DeviceAttributes>* response) override {
64  return errors::Unimplemented("not available in fake");
65  }
66  Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
67  const std::vector<string>& output_names,
68  const std::vector<string>& target_nodes,
69  std::vector<Tensor>* outputs) override {
70  RunMetadata run_metadata;
71  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
72  &run_metadata);
73  }
74  Status Run(const RunOptions& run_options,
75  const std::vector<std::pair<string, Tensor>>& inputs,
76  const std::vector<string>& output_names,
77  const std::vector<string>& target_nodes,
78  std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
79  return Run(run_options, inputs, output_names, target_nodes, outputs,
80  run_metadata, thread::ThreadPoolOptions());
81  }
82  Status Run(const RunOptions& run_options,
83  const std::vector<std::pair<string, Tensor>>& inputs,
84  const std::vector<string>& output_names,
85  const std::vector<string>& target_nodes,
86  std::vector<Tensor>* outputs, RunMetadata* run_metadata,
87  const thread::ThreadPoolOptions& thread_pool_options) override {
88  for (const auto& t : inputs) {
89  outputs->push_back(t.second);
90  }
91  return absl::OkStatus();
92  }
93 };
94 
// Test fixture that stands up shared ServerCore instances (one per test
// model) once for the whole suite, and provides helpers to fetch servable
// handles and invoke RunPredict against them.
class PredictImplTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() {
    // The "bad" model (missing signatures) is only available internally.
    if (!IsTensorflowServingOSS()) {
      const string bad_half_plus_two_path = test_util::TestSrcDirPath(
          "/servables/tensorflow/testdata/bad_half_plus_two");
      TF_ASSERT_OK(CreateServerCore(bad_half_plus_two_path,
                                    &saved_model_server_core_bad_model_));
    }

    TF_ASSERT_OK(CreateServerCore(test_util::TensorflowTestSrcDirPath(
                                      "cc/saved_model/testdata/half_plus_two"),
                                  &saved_model_server_core_));
    TF_ASSERT_OK(CreateServerCore(
        test_util::TestSrcDirPath(
            "/servables/tensorflow/testdata/saved_model_counter"),
        &saved_model_server_core_counter_model_));
  }

  static void TearDownTestSuite() {
    // Release the suite-scoped cores so servables unload cleanly.
    saved_model_server_core_.reset();
    saved_model_server_core_bad_model_.reset();
    saved_model_server_core_counter_model_.reset();
  }

 protected:
  // Builds a ServerCore serving the single model at `model_path` under
  // kTestModelName on the TensorFlow platform.
  static Status CreateServerCore(const string& model_path,
                                 std::unique_ptr<ServerCore>* server_core) {
    ModelServerConfig config;
    auto model_config = config.mutable_model_config_list()->add_config();
    model_config->set_name(kTestModelName);
    model_config->set_base_path(model_path);
    model_config->set_model_platform(kTensorFlowModelPlatform);

    // For ServerCore Options, we leave servable_state_monitor_creator
    // unspecified so the default servable_state_monitor_creator will be used.
    ServerCore::Options options;
    options.model_server_config = config;
    options.platform_config_map =
        CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
    options.aspired_version_policy =
        std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
    // Reduce the number of initial load threads to be num_load_threads to avoid
    // timing out in tests.
    options.num_initial_load_threads = options.num_load_threads;
    return ServerCore::Create(std::move(options), server_core);
  }

  // Accessors for the three suite-scoped cores.
  ServerCore* GetServerCore() {
    return saved_model_server_core_.get();
  }

  ServerCore* GetServerCoreWithBadModel() {
    return saved_model_server_core_bad_model_.get();
  }

  ServerCore* GetServerCoreWithCounterModel() {
    return saved_model_server_core_counter_model_.get();
  }

  // Looks up the SavedModelBundle servable for kTestModelName.
  Status GetSavedModelServableHandle(ServerCore* server_core,
                                     ServableHandle<SavedModelBundle>* bundle) {
    ModelSpec model_spec;
    model_spec.set_name(kTestModelName);
    return server_core->GetServableHandle(model_spec, bundle);
  }

  // Fetches the servable from `server_core` and runs RunPredict on it,
  // optionally with caller-supplied thread pools.
  Status CallPredict(ServerCore* server_core, const PredictRequest& request,
                     PredictResponse* response,
                     const thread::ThreadPoolOptions& thread_pool_options =
                         thread::ThreadPoolOptions()) {
    ServableHandle<SavedModelBundle> bundle;
    TF_RETURN_IF_ERROR(GetSavedModelServableHandle(server_core, &bundle));
    return RunPredict(GetRunOptions(), bundle->meta_graph_def,
                      kTestModelVersion, bundle->session.get(), request,
                      response, thread_pool_options);
  }

  RunOptions GetRunOptions() { return RunOptions(); }

 private:
  static std::unique_ptr<ServerCore> saved_model_server_core_;
  static std::unique_ptr<ServerCore> saved_model_server_core_bad_model_;
  static std::unique_ptr<ServerCore> saved_model_server_core_counter_model_;
};
180 
// Storage for the fixture's suite-scoped ServerCore instances (created in
// SetUpTestSuite, released in TearDownTestSuite).
std::unique_ptr<ServerCore> PredictImplTest::saved_model_server_core_;
std::unique_ptr<ServerCore> PredictImplTest::saved_model_server_core_bad_model_;
std::unique_ptr<ServerCore>
    PredictImplTest::saved_model_server_core_counter_model_;
185 
186 TEST_F(PredictImplTest, MissingOrEmptyModelSpec) {
187  PredictRequest request;
188  PredictResponse response;
189 
190  // Empty request is invalid.
191  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
192  CallPredict(GetServerCore(), request, &response).code());
193 
194  ModelSpec* model_spec = request.mutable_model_spec();
195  model_spec->clear_name();
196 
197  // Model name is not specified.
198  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
199  CallPredict(GetServerCore(), request, &response).code());
200 
201  // Model name is wrong.
202  model_spec->set_name("test");
203  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
204  CallPredict(GetServerCore(), request, &response).code());
205 }
206 
207 TEST_F(PredictImplTest, EmptyInputList) {
208  PredictRequest request;
209  PredictResponse response;
210 
211  ModelSpec* model_spec = request.mutable_model_spec();
212  model_spec->set_name(kTestModelName);
213  model_spec->mutable_version()->set_value(kTestModelVersion);
214 
215  // The input is empty.
216  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
217  CallPredict(GetServerCore(), request, &response).code());
218 }
219 
220 TEST_F(PredictImplTest, InputTensorsDontMatchModelSpecInputs) {
221  PredictRequest request;
222  PredictResponse response;
223  auto inputs = request.mutable_inputs();
224 
225  ModelSpec* model_spec = request.mutable_model_spec();
226  model_spec->set_name(kTestModelName);
227  model_spec->mutable_version()->set_value(kTestModelVersion);
228 
229  TensorProto tensor_proto1;
230  tensor_proto1.add_string_val("any_value");
231  tensor_proto1.set_dtype(tensorflow::DT_STRING);
232  tensor_proto1.mutable_tensor_shape()->add_dim()->set_size(1);
233  (*inputs)["unknown_key1"] = tensor_proto1;
234 
235  TensorProto tensor_proto2;
236  tensor_proto2.add_float_val(1.0);
237  tensor_proto2.set_dtype(tensorflow::DT_FLOAT);
238  tensor_proto2.mutable_tensor_shape()->add_dim()->set_size(1);
239  (*inputs)["unknown_key2"] = tensor_proto2;
240 
241  Status status = CallPredict(GetServerCore(), request, &response);
242  EXPECT_EQ(status.code(),
243  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
244  EXPECT_THAT(status.message(),
245  ::testing::HasSubstr("Sent extra: {unknown_key1,unknown_key2}"));
246  EXPECT_THAT(status.message(),
247  ::testing::HasSubstr(absl::StrCat("Missing but required: {",
248  kInputTensorKey, "}")));
249 }
250 
251 TEST_F(PredictImplTest, PredictionInvalidTensor) {
252  PredictRequest request;
253  PredictResponse response;
254 
255  ModelSpec* model_spec = request.mutable_model_spec();
256  model_spec->set_name(kTestModelName);
257  model_spec->mutable_version()->set_value(kTestModelVersion);
258 
259  TensorProto tensor_proto;
260  tensor_proto.add_bool_val(true);
261  tensor_proto.set_dtype(tensorflow::DT_BOOL);
262  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
263 
264  auto status = CallPredict(GetServerCore(), request, &response);
265  EXPECT_EQ(status.code(), tensorflow::error::Code::INVALID_ARGUMENT);
266  EXPECT_THAT(
267  status.message(),
268  ::testing::HasSubstr("Expects arg[0] to be float but bool is provided"));
269 }
270 
271 TEST_F(PredictImplTest, OutputFiltersDontMatchModelSpecOutputs) {
272  PredictRequest request;
273  PredictResponse response;
274 
275  ModelSpec* model_spec = request.mutable_model_spec();
276  model_spec->set_name(kTestModelName);
277  model_spec->mutable_version()->set_value(kTestModelVersion);
278 
279  TensorProto tensor_proto;
280  tensor_proto.add_float_val(2.0);
281  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
282  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
283  request.add_output_filter("output_filter");
284 
285  // Output filter like this doesn't exist.
286  Status status1 = CallPredict(GetServerCore(), request, &response);
287  EXPECT_EQ(status1.code(),
288  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
289  EXPECT_THAT(status1.message(),
290  ::testing::HasSubstr(
291  "output tensor alias not found in signature: output_filter"));
292 
293  request.clear_output_filter();
294  request.add_output_filter(kOutputTensorKey);
295  TF_EXPECT_OK(CallPredict(GetServerCore(), request, &response));
296  request.add_output_filter(kOutputTensorKey);
297 
298  // Duplicate output filter specified.
299  Status status2 = CallPredict(GetServerCore(), request, &response);
300  EXPECT_EQ(status2.code(),
301  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
302  EXPECT_THAT(status2.message(),
303  ::testing::HasSubstr("duplicate output tensor alias: y"));
304 }
305 
306 TEST_F(PredictImplTest, InputTensorsHaveWrongType) {
307  PredictRequest request;
308  PredictResponse response;
309 
310  ModelSpec* model_spec = request.mutable_model_spec();
311  model_spec->set_name(kTestModelName);
312  model_spec->mutable_version()->set_value(kTestModelVersion);
313 
314  TensorProto tensor_proto;
315  tensor_proto.add_string_val("any_value");
316  tensor_proto.set_dtype(tensorflow::DT_STRING);
317  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
318  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
319  request.add_output_filter(kOutputTensorKey);
320 
321  // Input tensors are all wrong.
322  Status status = CallPredict(GetServerCore(), request, &response);
323  EXPECT_EQ(status.code(),
324  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
325  EXPECT_THAT(status.message(),
326  ::testing::HasSubstr("to be float but string is provided"));
327 }
328 
329 TEST_F(PredictImplTest, ModelMissingSignatures) {
330  if (IsTensorflowServingOSS()) {
331  return;
332  }
333  PredictRequest request;
334  PredictResponse response;
335 
336  ModelSpec* model_spec = request.mutable_model_spec();
337  model_spec->set_name(kTestModelName);
338  model_spec->mutable_version()->set_value(kTestModelVersion);
339 
340  // Model is missing signatures.
341  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION,
342  CallPredict(GetServerCoreWithBadModel(),
343  request, &response).code());
344 }
345 
346 TEST_F(PredictImplTest, PredictionSuccess) {
347  PredictRequest request;
348  PredictResponse response;
349 
350  ModelSpec* model_spec = request.mutable_model_spec();
351  model_spec->set_name(kTestModelName);
352  model_spec->mutable_version()->set_value(kTestModelVersion);
353 
354  TensorProto tensor_proto;
355  tensor_proto.add_float_val(2.0);
356  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
357  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
358 
359  TF_EXPECT_OK(CallPredict(GetServerCore(), request, &response));
360  TensorProto output_tensor_proto;
361  output_tensor_proto.add_float_val(3);
362  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
363  output_tensor_proto.mutable_tensor_shape();
364  PredictResponse expected_response;
365  *expected_response.mutable_model_spec() = *model_spec;
366  expected_response.mutable_model_spec()->set_signature_name(
367  kDefaultServingSignatureDefKey);
368  (*expected_response.mutable_outputs())[kOutputTensorKey] =
369  output_tensor_proto;
370  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
371 }
372 
373 // Test querying a model with a named regression signature (not default). This
374 TEST_F(PredictImplTest, PredictionWithNamedRegressionSignature) {
375  PredictRequest request;
376  PredictResponse response;
377 
378  ModelSpec* model_spec = request.mutable_model_spec();
379  model_spec->set_name(kTestModelName);
380  model_spec->mutable_version()->set_value(kTestModelVersion);
381  model_spec->set_signature_name("regress_x2_to_y3");
382 
383  TensorProto tensor_proto;
384  tensor_proto.add_float_val(2.0);
385  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
386  (*request.mutable_inputs())[kRegressInputs] = tensor_proto;
387  TF_ASSERT_OK(CallPredict(GetServerCore(), request, &response));
388  TensorProto output_tensor_proto;
389  output_tensor_proto.add_float_val(4);
390  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
391  output_tensor_proto.mutable_tensor_shape();
392  PredictResponse expected_response;
393  *expected_response.mutable_model_spec() = *model_spec;
394  (*expected_response.mutable_outputs())[kRegressOutputs] = output_tensor_proto;
395  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
396 }
397 
398 // Test querying a model with a classification signature. Predict calls work
399 // with predict, classify, and regress signatures.
400 TEST_F(PredictImplTest, PredictionWithNamedClassificationSignature) {
401  PredictRequest request;
402  PredictResponse response;
403 
404  ModelSpec* model_spec = request.mutable_model_spec();
405  model_spec->set_name(kTestModelName);
406  model_spec->mutable_version()->set_value(kTestModelVersion);
407  model_spec->set_signature_name("classify_x2_to_y3");
408 
409  TensorProto tensor_proto;
410  tensor_proto.add_float_val(2.0);
411  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
412  (*request.mutable_inputs())[kClassifyInputs] = tensor_proto;
413 
414  TF_ASSERT_OK(CallPredict(GetServerCore(), request, &response));
415  TensorProto output_tensor_proto;
416  output_tensor_proto.add_float_val(4);
417  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
418  output_tensor_proto.mutable_tensor_shape();
419  PredictResponse expected_response;
420  *expected_response.mutable_model_spec() = *model_spec;
421  (*expected_response.mutable_outputs())[kClassifyOutputScores] =
422  output_tensor_proto;
423  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
424 }
425 
// Test querying a counter model with signatures. Predict calls work with
// customized signatures. It calls get_counter, incr_counter,
// reset_counter, incr_counter, and incr_counter_by(3) in order.
//
// *Notes*: These signatures are stateful and over-simplified only to
// demonstrate Predict calls with only inputs or outputs. State is not
// supported in TensorFlow Serving on most scalable or production hosting
// environments. The call order below matters: each step's expected value
// depends on the counter mutations performed by the previous steps.
TEST_F(PredictImplTest, PredictionWithCustomizedSignatures) {
  PredictRequest request;
  PredictResponse response;

  // Call get_counter. Expected result 0.
  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  model_spec->set_signature_name("get_counter");

  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(),
                           request, &response));

  PredictResponse expected_get_counter;
  *expected_get_counter.mutable_model_spec() = *model_spec;
  TensorProto output_get_counter;
  output_get_counter.add_float_val(0);
  output_get_counter.set_dtype(tensorflow::DT_FLOAT);
  output_get_counter.mutable_tensor_shape();
  (*expected_get_counter.mutable_outputs())["output"] = output_get_counter;
  EXPECT_THAT(response, test_util::EqualsProto(expected_get_counter));

  // Call incr_counter. Expect: 1.
  model_spec->set_signature_name("incr_counter");
  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(),
                           request, &response));

  PredictResponse expected_incr_counter;
  *expected_incr_counter.mutable_model_spec() = *model_spec;
  TensorProto output_incr_counter;
  output_incr_counter.add_float_val(1);
  output_incr_counter.set_dtype(tensorflow::DT_FLOAT);
  output_incr_counter.mutable_tensor_shape();
  (*expected_incr_counter.mutable_outputs())["output"] = output_incr_counter;
  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter));

  // Call reset_counter. Expect: 0.
  model_spec->set_signature_name("reset_counter");
  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(),
                           request, &response));

  PredictResponse expected_reset_counter;
  *expected_reset_counter.mutable_model_spec() = *model_spec;
  TensorProto output_reset_counter;
  output_reset_counter.add_float_val(0);
  output_reset_counter.set_dtype(tensorflow::DT_FLOAT);
  output_reset_counter.mutable_tensor_shape();
  (*expected_reset_counter.mutable_outputs())["output"] = output_reset_counter;
  EXPECT_THAT(response, test_util::EqualsProto(expected_reset_counter));

  // Call incr_counter again, this time with an explicit output filter.
  // Expect: 1.
  model_spec->set_signature_name("incr_counter");
  request.add_output_filter("output");
  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(),
                           request, &response));
  request.clear_output_filter();

  PredictResponse expected_incr_counter2;
  *expected_incr_counter2.mutable_model_spec() = *model_spec;
  TensorProto output_incr_counter2;
  output_incr_counter2.add_float_val(1);
  output_incr_counter2.set_dtype(tensorflow::DT_FLOAT);
  output_incr_counter2.mutable_tensor_shape();
  (*expected_incr_counter2.mutable_outputs())["output"] = output_incr_counter2;
  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter2));

  // Call incr_counter_by with delta=3 on top of the current value 1.
  // Expect: 4.
  model_spec->set_signature_name("incr_counter_by");
  TensorProto tensor_proto;
  tensor_proto.add_float_val(3);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())["delta"] = tensor_proto;

  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(),
                           request, &response));

  PredictResponse expected_incr_counter_by;
  *expected_incr_counter_by.mutable_model_spec() = *model_spec;
  TensorProto output_incr_counter_by;
  output_incr_counter_by.add_float_val(4);
  output_incr_counter_by.set_dtype(tensorflow::DT_FLOAT);
  output_incr_counter_by.mutable_tensor_shape();
  (*expected_incr_counter_by.mutable_outputs())["output"] =
      output_incr_counter_by;
  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter_by));
}
519 
520 TEST_F(PredictImplTest, ThreadPoolOptions) {
521  PredictRequest request;
522  PredictResponse response;
523 
524  ModelSpec* model_spec = request.mutable_model_spec();
525  model_spec->set_name(kTestModelName);
526  model_spec->mutable_version()->set_value(kTestModelVersion);
527 
528  TensorProto tensor_proto;
529  tensor_proto.add_float_val(2.0);
530  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
531  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
532 
533  test_util::CountingThreadPool inter_op_threadpool(Env::Default(), "InterOp",
534  /*num_threads=*/1);
535  test_util::CountingThreadPool intra_op_threadpool(Env::Default(), "IntraOp",
536  /*num_threads=*/1);
537  thread::ThreadPoolOptions thread_pool_options;
538  thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
539  thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
540  TF_EXPECT_OK(
541  CallPredict(GetServerCore(), request, &response, thread_pool_options));
542  TensorProto output_tensor_proto;
543  output_tensor_proto.add_float_val(3);
544  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
545  output_tensor_proto.mutable_tensor_shape();
546  PredictResponse expected_response;
547  *expected_response.mutable_model_spec() = *model_spec;
548  expected_response.mutable_model_spec()->set_signature_name(
549  kDefaultServingSignatureDefKey);
550  (*expected_response.mutable_outputs())[kOutputTensorKey] =
551  output_tensor_proto;
552  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
553 
554  // The intra_op_threadpool doesn't have anything scheduled.
555  ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
556 }
557 
// Exercises the signature method-name check feature flag: when enabled,
// RunPredict rejects signatures whose method_name is unsupported; when
// disabled, any method_name is accepted. Uses FakeSession so no real graph
// execution is needed. The flag is process-global, so the test saves and
// restores its prior value.
TEST_F(PredictImplTest, MethodNameCheck) {
  ServableHandle<SavedModelBundle> bundle;
  TF_ASSERT_OK(GetSavedModelServableHandle(GetServerCore(), &bundle));
  // Work on a local copy of the meta graph so its signatures can be edited.
  MetaGraphDef meta_graph_def = bundle->meta_graph_def;
  auto* signature_defs = meta_graph_def.mutable_signature_def();

  PredictRequest request;
  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;

  FakeSession fake_session;
  PredictResponse response;

  // Save the current flag value so it can be restored at the end.
  bool old_val = GetSignatureMethodNameCheckFeature();

  SetSignatureMethodNameCheckFeature(true);
  // Legit method name.
  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      kPredictMethodName);
  TF_EXPECT_OK(RunPredict(GetRunOptions(), meta_graph_def, kTestModelVersion,
                          &fake_session, request, &response,
                          thread::ThreadPoolOptions()));
  // Unsupported method name will fail check.
  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      "not/supported/method");
  EXPECT_FALSE(RunPredict(GetRunOptions(), meta_graph_def, kTestModelVersion,
                          &fake_session, request, &response,
                          thread::ThreadPoolOptions())
                   .ok());

  // With the check disabled, both method names are accepted.
  SetSignatureMethodNameCheckFeature(false);
  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      kPredictMethodName);
  TF_EXPECT_OK(RunPredict(GetRunOptions(), meta_graph_def, kTestModelVersion,
                          &fake_session, request, &response,
                          thread::ThreadPoolOptions()));
  // Unsupported method name should also work.
  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      "not/supported/method");
  TF_EXPECT_OK(RunPredict(GetRunOptions(), meta_graph_def, kTestModelVersion,
                          &fake_session, request, &response,
                          thread::ThreadPoolOptions()));

  SetSignatureMethodNameCheckFeature(old_val);
}
608 
609 } // namespace
610 } // namespace serving
611 } // namespace tensorflow
// Cross-reference (from the documentation index): static Status
// ServerCore::Create(Options options, std::unique_ptr<ServerCore>* core),
// defined in server_core.cc:231.