TensorFlow Serving C++ API Documentation
tfrt_classifier_test.cc
/* Copyright 2020 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "google/protobuf/map.h"
25 #include "tensorflow/cc/saved_model/signature_constants.h"
26 #include "tensorflow/core/example/example.pb.h"
27 #include "tensorflow/core/example/feature.pb.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/public/session.h"
35 #include "tensorflow/core/tfrt/utils/tensor_util.h"
36 #include "tsl/platform/env.h"
37 #include "tensorflow_serving/apis/classification.pb.h"
38 #include "tensorflow_serving/apis/input.pb.h"
39 #include "tensorflow_serving/apis/model.pb.h"
40 #include "tensorflow_serving/config/model_server_config.pb.h"
41 #include "tensorflow_serving/core/availability_preserving_policy.h"
42 #include "tensorflow_serving/core/servable_handle.h"
43 #include "tensorflow_serving/model_servers/model_platform_types.h"
44 #include "tensorflow_serving/model_servers/platform_config_util.h"
45 #include "tensorflow_serving/model_servers/server_core.h"
46 #include "tensorflow_serving/servables/tensorflow/servable.h"
47 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
48 #include "tensorflow_serving/servables/tensorflow/test_util/mock_tfrt_saved_model.h"
49 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
50 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
51 #include "tensorflow_serving/test_util/test_util.h"
52 
namespace tensorflow {
namespace serving {
namespace {

using ::testing::_;
using ::testing::DoAll;
using ::testing::HasSubstr;
using ::testing::Return;
using ::testing::WithArgs;

constexpr char kTestModelName[] = "test_model";
constexpr int kTestModelVersion = 123;

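// Test fixture: brings up a ServerCore that serves the
// saved_model_half_plus_two_cpu test model on the TFRT platform, and builds a
// single-example ClassificationRequest (request_) that the mock-based tests
// below reuse.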
class TfrtClassifierTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() {
    tfrt_stub::SetGlobalRuntime(
        tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4));
  }

  void SetUp() override {
    ModelServerConfig config;
    auto model_config = config.mutable_model_config_list()->add_config();
    model_config->set_name(kTestModelName);
    model_config->set_base_path(
        test_util::TestSrcDirPath("servables/tensorflow/"
                                  "testdata/saved_model_half_plus_two_cpu"));
    model_config->set_model_platform(kTensorFlowModelPlatform);

    // For ServerCore Options, we leave servable_state_monitor_creator
    // unspecified so the default servable_state_monitor_creator will be used.
    ServerCore::Options options;
    options.model_server_config = config;
    PlatformConfigMap platform_config_map;
    ::google::protobuf::Any source_adapter_config;
    TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
    source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
    (*(*platform_config_map
            .mutable_platform_configs())[kTensorFlowModelPlatform]
          .mutable_source_adapter_config()) = source_adapter_config;
    options.platform_config_map = platform_config_map;
    options.aspired_version_policy = std::unique_ptr<AspiredVersionPolicy>(
        new AvailabilityPreservingPolicy);
    // Reduce the number of initial load threads to be num_load_threads to
    // avoid timing out in tests.
    options.num_initial_load_threads = options.num_load_threads;
    TF_ASSERT_OK(ServerCore::Create(std::move(options), &server_core_));

    request_ = test_util::CreateProto<ClassificationRequest>(
        "model_spec {"
        "  name: \"test_model\""
        "  signature_name: \"classify_x_to_y\""
        "}"
        "input {"
        "  example_list {"
        "    examples {"
        "      features {"
        "        feature: {"
        "          key : \"x\""
        "          value: {"
        "            float_list: {"
        "              value: [ 20.0 ]"
        "            }"
        "          }"
        "        }"
        "      }"
        "    }"
        "  }"
        "}");
  }

  static void TearDownTestSuite() { server_core_ = nullptr; }

 protected:
  Status GetSavedModelServableHandle(ServerCore* server_core,
                                     ServableHandle<Servable>* servable) {
    ModelSpec model_spec;
    model_spec.set_name(kTestModelName);
    return server_core->GetServableHandle(model_spec, servable);
  }

  Status CallClassify(ServerCore* server_core,
                      const ClassificationRequest& request,
                      ClassificationResponse* response) {
    ServableHandle<Servable> servable;
    TF_RETURN_IF_ERROR(GetSavedModelServableHandle(server_core, &servable));
    return servable->Classify({}, request, response);
  }

  // Classifier valid after calling create.
  std::unique_ptr<ClassifierInterface> classifier_;
  static std::unique_ptr<ServerCore> server_core_;
  ClassificationRequest request_;
};

std::unique_ptr<ServerCore> TfrtClassifierTest::server_core_;

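// End-to-end classification through ServerCore. The half_plus_two test model
// computes y = x / 2 + 2, so inputs x = 80 and x = 20 are expected to score
// 42 and 12.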
TEST_F(TfrtClassifierTest, Basic) {
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      "  name: \"test_model\""
      "  signature_name: \"classify_x_to_y\""
      "}"
      "input {"
      "  example_list {"
      "    examples {"
      "      features {"
      "        feature: {"
      "          key : \"x\""
      "          value: {"
      "            float_list: {"
      "              value: [ 80.0 ]"
      "            }"
      "          }"
      "        }"
      "        feature: {"
      "          key : \"locale\""
      "          value: {"
      "            bytes_list: {"
      "              value: [ \"pt_BR\" ]"
      "            }"
      "          }"
      "        }"
      "        feature: {"
      "          key : \"age\""
      "          value: {"
      "            float_list: {"
      "              value: [ 19.0 ]"
      "            }"
      "          }"
      "        }"
      "      }"
      "    }"
      "    examples {"
      "      features {"
      "        feature: {"
      "          key : \"x\""
      "          value: {"
      "            float_list: {"
      "              value: [ 20.0 ]"
      "            }"
      "          }"
      "        }"
      "      }"
      "    }"
      "  }"
      "}");
  ClassificationResponse response;

  TF_EXPECT_OK(CallClassify(server_core_.get(), request, &response));
  EXPECT_THAT(response,
              test_util::EqualsProto(
                  "result { classifications { classes { "
                  "score: 42 } } classifications { classes { score: 12 } } }"
                  "model_spec {"
                  "  name: \"test_model\""
                  "  signature_name: \"classify_x_to_y\""
                  "  version { value: 123 }"
                  "}"));
}

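// Same two examples, this time wrapped in example_list_with_context. Each
// example already supplies "x", so the expected scores are unchanged.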
TEST_F(TfrtClassifierTest, BasicWithContext) {
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      "  name: \"test_model\""
      "  signature_name: \"classify_x_to_y\""
      "}"
      "input {"
      "  example_list_with_context {"
      "    examples {"
      "      features {"
      "        feature: {"
      "          key : \"x\""
      "          value: {"
      "            float_list: {"
      "              value: [ 80.0 ]"
      "            }"
      "          }"
      "        }"
      "        feature: {"
      "          key : \"locale\""
      "          value: {"
      "            bytes_list: {"
      "              value: [ \"pt_BR\" ]"
      "            }"
      "          }"
      "        }"
      "        feature: {"
      "          key : \"age\""
      "          value: {"
      "            float_list: {"
      "              value: [ 19.0 ]"
      "            }"
      "          }"
      "        }"
      "      }"
      "    }"
      "    examples {"
      "      features {"
      "        feature: {"
      "          key : \"x\""
      "          value: {"
      "            float_list: {"
      "              value: [ 20.0 ]"
      "            }"
      "          }"
      "        }"
      "      }"
      "    }"
      "    context: {"
      "      features: {"
      "        feature: {"
      "          key : \"x\""
      "          value: {"
      "            float_list: {"
      "              value: [ 10.0 ]"
      "            }"
      "          }"
      "        }"
      "      }"
      "    }"
      "  }"
      "}");
  ClassificationResponse response;

  TF_EXPECT_OK(CallClassify(server_core_.get(), request, &response));
  EXPECT_THAT(response, test_util::EqualsProto(
                            "result { classifications { classes { score: 42 } } "
                            "classifications { classes { score: 12 } } }"
                            "model_spec {"
                            "  name: \"test_model\""
                            "  signature_name: \"classify_x_to_y\""
                            "  version { value: 123 }"
                            "}"));
}

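// The next three tests verify that requests carrying no examples (an empty
// example_list, an example_list_with_context holding only context features,
// or an entirely empty input) are rejected with InvalidArgument and the
// message "Input is empty".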
TEST_F(TfrtClassifierTest, EmptyExampleList) {
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      "  name: \"test_model\""
      "  signature_name: \"classify_x_to_y\""
      "}"
      "input {"
      "  example_list {"
      "  }"
      "}");
  ClassificationResponse response;

  Status status = CallClassify(server_core_.get(), request, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));
}

TEST_F(TfrtClassifierTest, EmptyExampleListWithContext) {
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      "  name: \"test_model\""
      "  signature_name: \"classify_x_to_y\""
      "}"
      "input {"
      "  example_list_with_context {"
      "    context: {"
      "      features: {"
      "        feature: {"
      "          key : \"x\""
      "          value: {"
      "            float_list: {"
      "              value: [ 10.0 ]"
      "            }"
      "          }"
      "        }"
      "      }"
      "    }"
      "  }"
      "}");
  ClassificationResponse response;

  Status status = CallClassify(server_core_.get(), request, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));
}

TEST_F(TfrtClassifierTest, EmptyInput) {
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      "  name: \"test_model\""
      "  signature_name: \"classify_x_to_y\""
      "}"
      "input {"
      "}");
  ClassificationResponse response;

  Status status = CallClassify(server_core_.get(), request, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));
}

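// The remaining tests call RunClassify() directly against a MockSavedModel to
// exercise signature- and output-validation error paths. Here the signature
// lookup fails outright: GetFunctionMetadata returns std::nullopt, which
// surfaces as FailedPrecondition.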
TEST_F(TfrtClassifierTest, InvalidFunctionName) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(std::nullopt));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kFailedPrecondition);
  EXPECT_THAT(status.message(), HasSubstr("not found"));
}

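// A classify signature must take exactly one input tensor; an extra input
// name yields InvalidArgument.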
TEST_F(TfrtClassifierTest, InvalidFunctionUnmatchedInputSize) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {kClassifyInputs, "wrong input"};
  signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("Expected one input Tensor."));
}

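// A classify signature may produce at most two outputs (classes and scores);
// a third output name yields InvalidArgument.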
TEST_F(TfrtClassifierTest, InvalidFunctionUnmatchedOutputSize) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {kClassifyInputs};
  signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores,
                            "wrong output"};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("Expected one or two output Tensors"));
}

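// The single input must be named kClassifyInputs; otherwise no classification
// inputs are found in the function's metadata and the call fails with
// FailedPrecondition.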
TEST_F(TfrtClassifierTest, InvalidFunctionInvalidInputName) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {"wrong input"};
  signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kFailedPrecondition);
  EXPECT_THAT(
      status.message(),
      HasSubstr("No classification inputs found in function's metadata"));
}

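// Output names must be kClassifyOutputClasses and/or kClassifyOutputScores;
// an unrecognized output name yields FailedPrecondition.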
TEST_F(TfrtClassifierTest, InvalidFunctionInvalidOutputName) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {kClassifyInputs};
  signature.output_names = {"wrong output", kClassifyOutputScores};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kFailedPrecondition);
  EXPECT_THAT(status.message(),
              HasSubstr("Expected classification function outputs to contain"));
}

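// If SavedModel::Run() itself fails, its status ("test error") is propagated
// back to the caller.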
TEST_F(TfrtClassifierTest, RunsFails) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {kClassifyInputs};
  signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  EXPECT_CALL(*saved_model,
              Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
      .Times(1)
      .WillRepeatedly(Return(errors::InvalidArgument("test error")));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("test error"));
}

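// Run() must return one tensor per requested output name; producing a single
// tensor for the two requested outputs is rejected.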
TEST_F(TfrtClassifierTest, UnexpectedOutputTensorNumber) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {kClassifyInputs};
  signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  Tensor output;
  EXPECT_CALL(*saved_model,
              Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
      .Times(1)
      .WillRepeatedly(
          DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
                  output_tensors->push_back(output);
                }),
                Return(absl::OkStatus())));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("Unexpected output tensors size"));
}

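// The scores tensor must be rank 2; a [1, 1, 1] output tensor is rejected.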
TEST_F(TfrtClassifierTest, UnexpectedOutputTensorShape) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {kClassifyInputs};
  signature.output_names = {kClassifyOutputScores};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  Tensor output(DT_FLOAT, TensorShape({1, 1, 1}));
  EXPECT_CALL(*saved_model,
              Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
      .Times(1)
      .WillRepeatedly(
          DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
                  output_tensors->push_back(output);
                }),
                Return(absl::OkStatus())));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("Expected Tensor shape"));
}

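// The scores tensor must be DT_FLOAT; a DT_STRING output is rejected.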
TEST_F(TfrtClassifierTest, UnexpectedOutputTensorType) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {kClassifyInputs};
  signature.output_names = {kClassifyOutputScores};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  Tensor output(DT_STRING, TensorShape({1, 1}));
  EXPECT_CALL(*saved_model,
              Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
      .Times(1)
      .WillRepeatedly(
          DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
                  output_tensors->push_back(output);
                }),
                Return(absl::OkStatus())));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("Expected scores Tensor of DT_FLOAT"));
}

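// The scores tensor's batch dimension must match the number of input
// examples; a batch of 10 for the single-example request_ is rejected.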
TEST_F(TfrtClassifierTest, UnexpectedOutputTensorSize) {
  ClassificationResponse response;
  std::unique_ptr<test_util::MockSavedModel> saved_model(
      (new test_util::MockSavedModel()));
  tfrt::internal::Signature signature;
  signature.input_names = {kClassifyInputs};
  signature.output_names = {kClassifyOutputScores};
  tfrt::FunctionMetadata function_metadata(&signature);
  EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
      .Times(1)
      .WillRepeatedly(Return(function_metadata));
  Tensor output(DT_FLOAT, TensorShape({10, 1}));
  EXPECT_CALL(*saved_model,
              Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
      .Times(1)
      .WillRepeatedly(
          DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
                  output_tensors->push_back(output);
                }),
                Return(absl::OkStatus())));
  auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
                            saved_model.get(), request_, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("Expected scores output batch size of"));
}

}  // namespace
}  // namespace serving
}  // namespace tensorflow