// TensorFlow Serving C++ API Documentation
// tfrt_predict_util_test.cc
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/tfrt_predict_util.h"
17 
18 #include "absl/memory/memory.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/time/clock.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/cc/saved_model/loader.h"
23 #include "tensorflow/cc/saved_model/signature_constants.h"
24 #include "xla/tsl/lib/core/status_test_util.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/threadpool_options.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 #include "tensorflow/core/protobuf/meta_graph.pb.h"
30 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
31 #include "tensorflow/core/tfrt/utils/tensor_util.h"
32 #include "tensorflow_serving/core/availability_preserving_policy.h"
33 #include "tensorflow_serving/core/servable_handle.h"
34 #include "tensorflow_serving/model_servers/model_platform_types.h"
35 #include "tensorflow_serving/model_servers/platform_config_util.h"
36 #include "tensorflow_serving/model_servers/server_core.h"
37 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
38 #include "tensorflow_serving/servables/tensorflow/servable.h"
39 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
40 #include "tensorflow_serving/servables/tensorflow/test_util/mock_tfrt_saved_model.h"
41 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
42 #include "tensorflow_serving/test_util/test_util.h"
43 #include "tensorflow_serving/util/oss_or_google.h"
44 
45 namespace tensorflow {
46 namespace serving {
47 namespace {
48 using ::testing::_;
49 using ::testing::DoAll;
50 using ::testing::HasSubstr;
51 using ::testing::Return;
52 using ::testing::ReturnRef;
53 using ::testing::WithArgs;
54 
55 constexpr char kTestModelName[] = "test_model";
56 constexpr int kTestModelVersion = 123;
57 
58 const char kInputTensorKey[] = "x";
59 const char kOutputTensorKey[] = "y";
60 
// Test fixture that brings up a ServerCore serving the
// "saved_model_half_plus_two_tf2_cpu" test model (y = 0.5x + 2) on the
// TFRT saved-model platform. The core is shared by all tests in the suite.
class PredictImplTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() {
    // The global TFRT runtime must exist before any saved model is loaded.
    tfrt_stub::SetGlobalRuntime(
        tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4));

    ModelServerConfig config;
    auto model_config = config.mutable_model_config_list()->add_config();
    model_config->set_name(kTestModelName);
    model_config->set_base_path(
        test_util::TestSrcDirPath("servables/tensorflow/testdata/"
                                  "saved_model_half_plus_two_tf2_cpu"));
    model_config->set_model_platform(kTensorFlowModelPlatform);

    // For ServerCore Options, we leave servable_state_monitor_creator
    // unspecified so the default servable_state_monitor_creator will be used.
    ServerCore::Options options;
    options.model_server_config = config;
    // Wire the TFRT saved-model source adapter config into the platform map
    // under the TensorFlow platform name.
    PlatformConfigMap platform_config_map;
    ::google::protobuf::Any source_adapter_config;
    TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
    source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
    (*(*platform_config_map
            .mutable_platform_configs())[kTensorFlowModelPlatform]
          .mutable_source_adapter_config()) = source_adapter_config;
    options.platform_config_map = platform_config_map;
    options.aspired_version_policy =
        std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
    // Reduce the number of initial load threads to be num_load_threads to avoid
    // timing out in tests.
    options.num_initial_load_threads = options.num_load_threads;
    TF_ASSERT_OK(
        ServerCore::Create(std::move(options), &saved_model_server_core_));
  }

  static void TearDownTestSuite() { saved_model_server_core_.reset(); }

 protected:
  // Looks up the servable handle for kTestModelName (no explicit version) from
  // `server_core`.
  Status GetSavedModelServableHandle(ServerCore* server_core,
                                     ServableHandle<Servable>* servable) {
    ModelSpec model_spec;
    model_spec.set_name(kTestModelName);
    return server_core->GetServableHandle(model_spec, servable);
  }

  ServerCore* GetServerCore() { return saved_model_server_core_.get(); }

  // Runs Predict on the resolved servable. A non-zero `timeout` is converted
  // into an absolute deadline in the run options; absl::ZeroDuration() (the
  // default) means no deadline.
  Status CallPredict(ServerCore* server_core, const PredictRequest& request,
                     PredictResponse* response,
                     absl::Duration timeout = absl::ZeroDuration()) {
    ServableHandle<Servable> servable;
    TF_RETURN_IF_ERROR(GetSavedModelServableHandle(server_core, &servable));

    // Set deadline in run options.
    Servable::RunOptions run_options;
    if (timeout != absl::ZeroDuration())
      run_options.deadline = absl::Now() + timeout;
    return servable->Predict(run_options, request, response);
  }

 private:
  // Shared across all tests in the suite; created once in SetUpTestSuite().
  static std::unique_ptr<ServerCore> saved_model_server_core_;
};

std::unique_ptr<ServerCore> PredictImplTest::saved_model_server_core_;
126 
127 TEST_F(PredictImplTest, PredictionSuccess) {
128  PredictRequest request;
129  PredictResponse response;
130 
131  ModelSpec* model_spec = request.mutable_model_spec();
132  model_spec->set_name(kTestModelName);
133  model_spec->mutable_version()->set_value(kTestModelVersion);
134 
135  TensorProto tensor_proto;
136  tensor_proto.add_float_val(2.0);
137  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
138  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
139 
140  TF_EXPECT_OK(CallPredict(GetServerCore(), request, &response));
141  TensorProto output_tensor_proto;
142  output_tensor_proto.add_float_val(3);
143  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
144  output_tensor_proto.mutable_tensor_shape();
145  PredictResponse expected_response;
146  *expected_response.mutable_model_spec() = *model_spec;
147  expected_response.mutable_model_spec()->set_signature_name(
148  kDefaultServingSignatureDefKey);
149  (*expected_response.mutable_outputs())[kOutputTensorKey] =
150  output_tensor_proto;
151  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
152 }
153 
154 TEST_F(PredictImplTest, PredictionSuccessWithDefaultInputs) {
155  PredictRequest request;
156  PredictResponse response;
157 
158  ModelSpec* model_spec = request.mutable_model_spec();
159  model_spec->set_name(kTestModelName);
160  model_spec->mutable_version()->set_value(kTestModelVersion);
161 
162  // prediction result = 0.5x + 2 = 2 with x defaults to 0.
163  TF_EXPECT_OK(CallPredict(GetServerCore(), request, &response));
164  TensorProto output_tensor_proto;
165  output_tensor_proto.add_float_val(2);
166  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
167  output_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
168  PredictResponse expected_response;
169  *expected_response.mutable_model_spec() = *model_spec;
170  expected_response.mutable_model_spec()->set_signature_name(
171  kDefaultServingSignatureDefKey);
172  (*expected_response.mutable_outputs())[kOutputTensorKey] =
173  output_tensor_proto;
174  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
175 }
176 
177 TEST_F(PredictImplTest, PredictionInvalidTensor) {
178  PredictRequest request;
179  PredictResponse response;
180 
181  ModelSpec* model_spec = request.mutable_model_spec();
182  model_spec->set_name(kTestModelName);
183  model_spec->mutable_version()->set_value(kTestModelVersion);
184 
185  TensorProto tensor_proto;
186  tensor_proto.add_bool_val(true);
187  tensor_proto.set_dtype(tensorflow::DT_BOOL);
188  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
189 
190  auto status = CallPredict(GetServerCore(), request, &response);
191  EXPECT_EQ(status.code(), tensorflow::error::Code::INVALID_ARGUMENT);
192  EXPECT_THAT(status.message(), HasSubstr("Expected input x to be float"));
193 }
194 
195 TEST_F(PredictImplTest, PredictionMissingFunction) {
196  PredictRequest request;
197  PredictResponse response;
198 
199  TensorProto tensor_proto;
200  tensor_proto.add_float_val(2.0);
201  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
202  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
203 
204  std::unique_ptr<test_util::MockSavedModel> saved_model(
205  (new test_util::MockSavedModel()));
206  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
207  .Times(1)
208  .WillRepeatedly(Return(std::nullopt));
209  auto status =
210  RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
211  saved_model.get(), request, &response);
212  EXPECT_EQ(status.code(), tensorflow::error::Code::FAILED_PRECONDITION);
213  EXPECT_THAT(status.message(), HasSubstr("not found"));
214 }
215 
216 TEST_F(PredictImplTest, PredictionMissingInput) {
217  PredictRequest request;
218  request.mutable_model_spec()->set_name(kTestModelName);
219  PredictResponse response;
220 
221  TensorProto tensor_proto;
222  tensor_proto.add_float_val(2.0);
223  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
224  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
225 
226  std::unique_ptr<test_util::MockSavedModel> saved_model(
227  (new test_util::MockSavedModel()));
228  tfrt::internal::Signature signature;
229  signature.input_names = {"unknown"};
230  tfrt::FunctionMetadata function_metadata(&signature);
231  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
232  .Times(1)
233  .WillRepeatedly(Return(function_metadata));
234  auto status =
235  RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
236  saved_model.get(), request, &response);
237  EXPECT_EQ(status.code(), tensorflow::error::Code::INVALID_ARGUMENT);
238  EXPECT_THAT(
239  status.message(),
240  HasSubstr(
241  "Request inputs do not match required inputs for model "
242  "`test_model`. Send extra: {x}. Missing but required: {unknown}."));
243 }
244 
245 TEST_F(PredictImplTest, PredictionRunError) {
246  PredictRequest request;
247  PredictResponse response;
248 
249  TensorProto tensor_proto;
250  tensor_proto.add_float_val(2.0);
251  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
252  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
253 
254  std::unique_ptr<test_util::MockSavedModel> saved_model(
255  (new test_util::MockSavedModel()));
256  tfrt::internal::Signature signature;
257  signature.input_names = {"x"};
258  tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
259  signature.input_specs = {spec};
260  tfrt::FunctionMetadata function_metadata(&signature);
261  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
262  .Times(1)
263  .WillRepeatedly(Return(function_metadata));
264  EXPECT_CALL(*saved_model,
265  Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
266  .Times(1)
267  .WillRepeatedly(Return(errors::InvalidArgument("test error")));
268  auto status =
269  RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
270  saved_model.get(), request, &response);
271  EXPECT_EQ(status.code(), tensorflow::error::Code::INVALID_ARGUMENT);
272  EXPECT_THAT(status.message(), HasSubstr("test error"));
273 }
274 
275 TEST_F(PredictImplTest, PredictionUnmatchedOutputNumber) {
276  PredictRequest request;
277  PredictResponse response;
278 
279  TensorProto tensor_proto;
280  tensor_proto.add_float_val(2.0);
281  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
282  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
283 
284  std::unique_ptr<test_util::MockSavedModel> saved_model(
285  (new test_util::MockSavedModel()));
286  tfrt::internal::Signature signature;
287  signature.input_names = {"x"};
288  tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
289  signature.input_specs = {spec};
290  tfrt::FunctionMetadata function_metadata(&signature);
291  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
292  .Times(1)
293  .WillRepeatedly(Return(function_metadata));
294 
295  Tensor output;
296  EXPECT_CALL(*saved_model,
297  Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
298  .Times(1)
299  .WillRepeatedly(
300  DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
301  output_tensors->push_back(output);
302  output_tensors->push_back(output);
303  }),
304  Return(absl::OkStatus())));
305  auto status =
306  RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
307  saved_model.get(), request, &response);
308  EXPECT_EQ(status.code(), tensorflow::error::Code::UNKNOWN);
309  EXPECT_THAT(status.message(), HasSubstr("Predict internal error."));
310 }
311 
// When the request filters to a strict subset of outputs, RunPredict must
// resolve tensor names through the SignatureDef and call RunByTensorNames
// with only the filtered output names (SizeIs(1) below).
TEST_F(PredictImplTest, OutputFilters) {
  PredictRequest request;
  PredictResponse response;

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
  // Request only "output1" out of the signature's {output1, output2}.
  request.add_output_filter("output1");

  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {"x"};
  tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
  signature.input_specs = {spec};
  signature.output_names = {"output1", "output2"};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));

  // SignatureDef providing the alias -> tensor-name mapping that the filtered
  // path consults via GetMetaGraphDef().
  tensorflow::SignatureDef signature_def;
  tensorflow::TensorInfo tensor_info1, tensor_info2, tensor_info3;
  tensor_info1.set_name("x");
  tensor_info2.set_name("output1");
  tensor_info3.set_name("output2");
  signature_def.mutable_inputs()->insert({kInputTensorKey, tensor_info1});
  signature_def.mutable_outputs()->insert({"output1", tensor_info2});
  signature_def.mutable_outputs()->insert({"output2", tensor_info3});
  signature_def.set_method_name("tensorflow/serving/predict");

  tensorflow::MetaGraphDef meta_graph_def;
  meta_graph_def.mutable_signature_def()->insert(
      {"serving_default", signature_def});
  EXPECT_CALL(*saved_model, GetMetaGraphDef())
      .Times(1)
      .WillRepeatedly(ReturnRef(meta_graph_def));

  TensorProto output_tensor_proto1;
  output_tensor_proto1.add_float_val(1.0);
  output_tensor_proto1.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto1.mutable_tensor_shape();

  // Exactly one output tensor name must be requested (the filtered one).
  EXPECT_CALL(
      *saved_model,
      RunByTensorNames(_, _, ::testing::SizeIs(1),
                       ::testing::An<absl::Span<const std::string>>(), _))
      .Times(1)
      .WillRepeatedly(
          DoAll(WithArgs<4>([&](std::vector<Tensor>* output_tensors) {
                  Tensor output_tensor;
                  CHECK(output_tensor.FromProto(output_tensor_proto1));
                  output_tensors->push_back(output_tensor);
                }),
                Return(absl::OkStatus())));
  TF_EXPECT_OK(RunPredict(tfrt_stub::SavedModel::RunOptions(),
                          kTestModelVersion, saved_model.get(), request,
                          &response));
  // Only the filtered output appears in the response.
  EXPECT_EQ(response.outputs_size(), 1);
  EXPECT_TRUE(response.outputs().find("output1") != response.outputs().end());
  EXPECT_THAT(response.outputs().at("output1"),
              test_util::EqualsProto(output_tensor_proto1));
}
376 
// When the output filter names every declared output, the filter is a no-op:
// RunPredict must take the plain Run() path and never touch the MetaGraphDef
// or RunByTensorNames.
TEST_F(PredictImplTest, OutputFiltersFullSet) {
  PredictRequest request;
  PredictResponse response;

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
  // Filter lists the full output set {output1, output2}.
  request.add_output_filter("output1");
  request.add_output_filter("output2");

  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {"x"};
  tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
  signature.input_specs = {spec};
  signature.output_names = {"output1", "output2"};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));

  // if the output_filter is a full set, we should still call Run(), since full
  // set is equivalent to an empty filter.
  EXPECT_CALL(*saved_model, GetMetaGraphDef()).Times(0);
  EXPECT_CALL(
      *saved_model,
      RunByTensorNames(_, _, ::testing::SizeIs(1),
                       ::testing::An<absl::Span<const std::string>>(), _))
      .Times(0);
  TensorProto output_tensor_proto1;
  output_tensor_proto1.add_float_val(1.0);
  output_tensor_proto1.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto1.mutable_tensor_shape();
  TensorProto output_tensor_proto2;
  output_tensor_proto2.add_float_val(2.0);
  output_tensor_proto2.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto2.mutable_tensor_shape();
  // Run() produces both outputs in signature order.
  EXPECT_CALL(*saved_model, Run(_, _, _, _))
      .Times(1)
      .WillRepeatedly(
          DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
                  Tensor output_tensor1;
                  CHECK(output_tensor1.FromProto(output_tensor_proto1));
                  output_tensors->push_back(output_tensor1);
                  Tensor output_tensor2;
                  CHECK(output_tensor2.FromProto(output_tensor_proto2));
                  output_tensors->push_back(output_tensor2);
                }),
                Return(absl::OkStatus())));

  TF_EXPECT_OK(RunPredict(tfrt_stub::SavedModel::RunOptions(),
                          kTestModelVersion, saved_model.get(), request,
                          &response));
  // Both outputs must be present and keyed by their aliases.
  EXPECT_EQ(response.outputs_size(), 2);
  EXPECT_TRUE(response.outputs().find("output1") != response.outputs().end());
  EXPECT_THAT(response.outputs().at("output1"),
              test_util::EqualsProto(output_tensor_proto1));
  EXPECT_TRUE(response.outputs().find("output2") != response.outputs().end());
  EXPECT_THAT(response.outputs().at("output2"),
              test_util::EqualsProto(output_tensor_proto2));
}
440 
441 TEST_F(PredictImplTest, OutputFiltersWithDefaultInputs) {
442  PredictRequest request;
443  PredictResponse response;
444 
445  request.add_output_filter("output1");
446 
447  std::unique_ptr<test_util::MockSavedModel> saved_model(
448  (new test_util::MockSavedModel()));
449  tfrt::internal::Signature signature;
450  signature.input_names = {"x"};
451  tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
452  signature.input_specs = {spec};
453  signature.output_names = {"output1", "output2"};
454  Tensor tensor(0);
455  TensorProto tensor_proto;
456  tensor.AsProtoTensorContent(&tensor_proto);
457  signature.default_inputs[kInputTensorKey] = tensor_proto;
458  tfrt::FunctionMetadata function_metadata(&signature);
459  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
460  .Times(1)
461  .WillRepeatedly(Return(function_metadata));
462 
463  tensorflow::SignatureDef signature_def;
464  tensorflow::TensorInfo tensor_info1, tensor_info2, tensor_info3;
465  tensor_info1.set_name("x");
466  tensor_info2.set_name("output1");
467  tensor_info3.set_name("output2");
468  signature_def.mutable_inputs()->insert({kInputTensorKey, tensor_info1});
469  signature_def.mutable_outputs()->insert({"output1", tensor_info2});
470  signature_def.mutable_outputs()->insert({"output2", tensor_info3});
471  signature_def.set_method_name("tensorflow/serving/predict");
472  (*signature_def.mutable_defaults())[kInputTensorKey] = tensor_proto;
473 
474  tensorflow::MetaGraphDef meta_graph_def;
475  meta_graph_def.mutable_signature_def()->insert(
476  {"serving_default", signature_def});
477  EXPECT_CALL(*saved_model, GetMetaGraphDef())
478  .Times(1)
479  .WillRepeatedly(ReturnRef(meta_graph_def));
480 
481  TensorProto output_tensor_proto1;
482  output_tensor_proto1.add_float_val(1.0);
483  output_tensor_proto1.set_dtype(tensorflow::DT_FLOAT);
484  output_tensor_proto1.mutable_tensor_shape();
485 
486  EXPECT_CALL(
487  *saved_model,
488  RunByTensorNames(_, ::testing::SizeIs(1), ::testing::SizeIs(1),
489  ::testing::An<absl::Span<const std::string>>(), _))
490  .Times(1)
491  .WillRepeatedly(
492  DoAll(WithArgs<4>([&](std::vector<Tensor>* output_tensors) {
493  Tensor output_tensor;
494  CHECK(output_tensor.FromProto(output_tensor_proto1));
495  output_tensors->push_back(output_tensor);
496  }),
497  Return(absl::OkStatus())));
498  TF_EXPECT_OK(RunPredict(tfrt::SavedModel::RunOptions(), kTestModelVersion,
499  saved_model.get(), request, &response));
500  EXPECT_EQ(response.outputs_size(), 1);
501  EXPECT_TRUE(response.outputs().find("output1") != response.outputs().end());
502  EXPECT_THAT(response.outputs().at("output1"),
503  test_util::EqualsProto(output_tensor_proto1));
504 }
505 
// A filter naming an output alias that does not exist in the signature
// ("output3") must fail with INVALID_ARGUMENT listing the valid aliases.
TEST_F(PredictImplTest, UnmatchedOutputFilters) {
  PredictRequest request;
  PredictResponse response;

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
  // "output1" is valid; "output3" is not declared by the signature.
  request.add_output_filter("output1");
  request.add_output_filter("output3");

  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {"x"};
  tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
  signature.input_specs = {spec};
  signature.output_names = {"output1", "output2"};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  Tensor output_tensor;

  // SignatureDef only declares output1/output2; used to validate the filter.
  tensorflow::SignatureDef signature_def;
  tensorflow::TensorInfo tensor_info1, tensor_info2, tensor_info3;
  tensor_info1.set_name("x");
  tensor_info2.set_name("output1");
  tensor_info3.set_name("output2");
  signature_def.mutable_inputs()->insert({kInputTensorKey, tensor_info1});
  signature_def.mutable_outputs()->insert({"output1", tensor_info2});
  signature_def.mutable_outputs()->insert({"output2", tensor_info3});
  signature_def.set_method_name("tensorflow/serving/predict");

  tensorflow::MetaGraphDef meta_graph_def;
  meta_graph_def.mutable_signature_def()->insert(
      {"serving_default", signature_def});
  EXPECT_CALL(*saved_model, GetMetaGraphDef())
      .Times(1)
      .WillRepeatedly(ReturnRef(meta_graph_def));

  auto status =
      RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
                 saved_model.get(), request, &response);
  EXPECT_EQ(status.code(), tensorflow::error::Code::INVALID_ARGUMENT);
  EXPECT_THAT(
      status.message(),
      HasSubstr("output tensor alias not found in signature: output3 Outputs "
                "expected to be in the set {output1,output2}."));
}
556 
557 TEST_F(PredictImplTest, PredictionTimeout) {
558  PredictRequest request;
559  PredictResponse response;
560 
561  ModelSpec* model_spec = request.mutable_model_spec();
562  model_spec->set_name(kTestModelName);
563  model_spec->mutable_version()->set_value(kTestModelVersion);
564 
565  TensorProto tensor_proto;
566  tensor_proto.add_float_val(2.0);
567  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
568  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
569 
570  // Set the deadline to be 1 nanosecond from now. This makes the timer
571  // timeout before the request completes.
572  auto status =
573  CallPredict(GetServerCore(), request, &response, absl::Nanoseconds(1));
574 
575  EXPECT_EQ(status.code(), tensorflow::error::Code::DEADLINE_EXCEEDED);
576  EXPECT_EQ(status.message(), "Deadline exceeded.");
577 }
578 
579 } // namespace
580 } // namespace serving
581 } // namespace tensorflow
// (doc-scrape residue — doxygen tooltip, kept as a comment:)
// static Status Create(Options options, std::unique_ptr<ServerCore>* core)
// Definition: server_core.cc:231