// TensorFlow Serving — tfrt_classification_service_test.cc
/* Copyright 2020 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/tfrt_classification_service.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/threadpool_options.h"
24 #include "tensorflow/core/protobuf/config.pb.h"
25 #include "tensorflow/core/tfrt/runtime/runtime.h"
26 #include "tensorflow_serving/config/model_server_config.pb.h"
27 #include "tensorflow_serving/core/availability_preserving_policy.h"
28 #include "tensorflow_serving/model_servers/model_platform_types.h"
29 #include "tensorflow_serving/model_servers/platform_config_util.h"
30 #include "tensorflow_serving/model_servers/server_core.h"
31 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
32 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
33 #include "tensorflow_serving/test_util/test_util.h"
34 
35 namespace tensorflow {
36 namespace serving {
37 namespace {
38 
// Name under which the test fixture's ServerCore serves the half_plus_two
// SavedModel; requests reference the model by this name.
constexpr char kTestModelName[] = "test_model";
40 
41 // Test fixture for ClassificationService related tests sets up a ServerCore
42 // pointing to the half_plus_two SavedModel.
43 class TFRTClassificationServiceTest : public ::testing::Test {
44  public:
45  static void SetUpTestSuite() {
46  tfrt_stub::SetGlobalRuntime(
47  tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4));
48 
49  ModelServerConfig config;
50  auto model_config = config.mutable_model_config_list()->add_config();
51  model_config->set_name(kTestModelName);
52  model_config->set_base_path(
53  test_util::TestSrcDirPath("servables/tensorflow/"
54  "testdata/saved_model_half_plus_two_cpu"));
55  model_config->set_model_platform(kTensorFlowModelPlatform);
56 
57  // For ServerCore Options, we leave servable_state_monitor_creator
58  // unspecified so the default servable_state_monitor_creator will be used.
59  ServerCore::Options options;
60  options.model_server_config = config;
61  PlatformConfigMap platform_config_map;
62  ::google::protobuf::Any source_adapter_config;
63  TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
64  source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
65  (*(*platform_config_map
66  .mutable_platform_configs())[kTensorFlowModelPlatform]
67  .mutable_source_adapter_config()) = source_adapter_config;
68  options.platform_config_map = platform_config_map;
69  options.aspired_version_policy =
70  std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
71  // Reduce the number of initial load threads to be num_load_threads to avoid
72  // timing out in tests.
73  options.num_initial_load_threads = options.num_load_threads;
74  TF_ASSERT_OK(ServerCore::Create(std::move(options), &server_core_));
75  }
76 
77  static void TearDownTestSuite() { server_core_ = nullptr; }
78 
79  protected:
80  static std::unique_ptr<ServerCore> server_core_;
81  Servable::RunOptions run_options_;
82 };
83 
// Out-of-class definition of the fixture's shared ServerCore instance.
std::unique_ptr<ServerCore> TFRTClassificationServiceTest::server_core_;
85 
86 // Verifies that Classify() returns an error for different cases of an invalid
87 // ClassificationRequest.model_spec.
88 TEST_F(TFRTClassificationServiceTest, InvalidModelSpec) {
89  ClassificationRequest request;
90  ClassificationResponse response;
91 
92  // No model_spec specified.
93  EXPECT_EQ(TFRTClassificationServiceImpl::Classify(
94  run_options_, server_core_.get(), request, &response)
95  .code(),
96  absl::StatusCode::kInvalidArgument);
97 
98  // No model name specified.
99  auto* model_spec = request.mutable_model_spec();
100  EXPECT_EQ(TFRTClassificationServiceImpl::Classify(
101  run_options_, server_core_.get(), request, &response)
102  .code(),
103  absl::StatusCode::kInvalidArgument);
104 
105  // No servable found for model name "foo".
106  model_spec->set_name("foo");
107  EXPECT_EQ(TFRTClassificationServiceImpl::Classify(
108  run_options_, server_core_.get(), request, &response)
109  .code(),
110  tensorflow::error::NOT_FOUND);
111 }
112 
113 // Verifies that Classify() returns an error for an invalid signature_name in
114 // ClassificationRequests's model_spec.
115 TEST_F(TFRTClassificationServiceTest, InvalidSignature) {
116  auto request = test_util::CreateProto<ClassificationRequest>(
117  "model_spec {"
118  " name: \"test_model\""
119  " signature_name: \"invalid_signature_name\""
120  "}");
121  ClassificationResponse response;
122  EXPECT_EQ(TFRTClassificationServiceImpl::Classify(
123  run_options_, server_core_.get(), request, &response)
124  .code(),
125  tensorflow::error::FAILED_PRECONDITION);
126 }
127 
// Verifies that Classify() returns the correct score for a valid
// ClassificationRequest against the half_plus_two SavedModel's classify_x_to_y
// signature.
TEST_F(TFRTClassificationServiceTest, ClassificationSuccess) {
  // A single Example with x = 80; the extra "locale" and "age" features are
  // not consumed by the signature and exercise feature filtering.
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      " name: \"test_model\""
      " signature_name: \"classify_x_to_y\""
      "}"
      "input {"
      " example_list {"
      " examples {"
      " features {"
      " feature: {"
      " key : \"x\""
      " value: {"
      " float_list: {"
      " value: [ 80.0 ]"
      " }"
      " }"
      " }"
      " feature: {"
      " key : \"locale\""
      " value: {"
      " bytes_list: {"
      " value: [ \"pt_BR\" ]"
      " }"
      " }"
      " }"
      " feature: {"
      " key : \"age\""
      " value: {"
      " float_list: {"
      " value: [ 19.0 ]"
      " }"
      " }"
      " }"
      " }"
      " }"
      " }"
      "}");
  ClassificationResponse response;
  TF_EXPECT_OK(TFRTClassificationServiceImpl::Classify(
      run_options_, server_core_.get(), request, &response));
  // Expected score 42 is consistent with half_plus_two computing x/2 + 2 for
  // x = 80 (presumably; verify against the model testdata). The response must
  // also echo a fully resolved model_spec, including the served version 123.
  EXPECT_THAT(response,
              test_util::EqualsProto(
                  "result { classifications { classes { score: 42 } } }"
                  "model_spec {"
                  " name: \"test_model\""
                  " signature_name: \"classify_x_to_y\""
                  " version { value: 123 }"
                  "}"));
}
181 
182 // Verifies that ClassifyWithModelSpec() uses the model spec override rather
183 // than the one in the request.
184 TEST_F(TFRTClassificationServiceTest, ModelSpecOverride) {
185  auto request = test_util::CreateProto<ClassificationRequest>(
186  "model_spec {"
187  " name: \"test_model\""
188  "}");
189  auto model_spec_override =
190  test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");
191 
192  ClassificationResponse response;
193  EXPECT_NE(tensorflow::error::NOT_FOUND,
194  TFRTClassificationServiceImpl::Classify(
195  run_options_, server_core_.get(), request, &response)
196  .code());
197  EXPECT_EQ(tensorflow::error::NOT_FOUND,
198  TFRTClassificationServiceImpl::ClassifyWithModelSpec(
199  run_options_, server_core_.get(), model_spec_override, request,
200  &response)
201  .code());
202 }
203 
204 } // namespace
205 } // namespace serving
206 } // namespace tensorflow
// Reference: ServerCore::Create(Options options, std::unique_ptr<ServerCore>*
// core) is defined in model_servers/server_core.cc.