// TensorFlow Serving C++ API Documentation
// classification_service_test.cc
1 /* Copyright 2018 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/classification_service.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/threadpool_options.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow_serving/config/model_server_config.pb.h"
28 #include "tensorflow_serving/core/availability_preserving_policy.h"
29 #include "tensorflow_serving/model_servers/model_platform_types.h"
30 #include "tensorflow_serving/model_servers/platform_config_util.h"
31 #include "tensorflow_serving/model_servers/server_core.h"
32 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
33 #include "tensorflow_serving/test_util/test_util.h"
34 
35 namespace tensorflow {
36 namespace serving {
37 namespace {
38 
39 constexpr char kTestModelName[] = "test_model";
40 
41 // Test fixture for ClassificationService related tests sets up a ServerCore
42 // pointing to the half_plus_two SavedModel.
43 class ClassificationServiceTest : public ::testing::Test {
44  public:
45  static void SetUpTestSuite() {
46  ModelServerConfig config;
47  auto model_config = config.mutable_model_config_list()->add_config();
48  model_config->set_name(kTestModelName);
49  model_config->set_base_path(test_util::TensorflowTestSrcDirPath(
50  "cc/saved_model/testdata/half_plus_two"));
51  model_config->set_model_platform(kTensorFlowModelPlatform);
52 
53  // For ServerCore Options, we leave servable_state_monitor_creator
54  // unspecified so the default servable_state_monitor_creator will be used.
55  ServerCore::Options options;
56  options.model_server_config = config;
57  options.platform_config_map =
58  CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
59  options.aspired_version_policy =
60  std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
61  // Reduce the number of initial load threads to be num_load_threads to avoid
62  // timing out in tests.
63  options.num_initial_load_threads = options.num_load_threads;
64  TF_ASSERT_OK(ServerCore::Create(std::move(options), &server_core_));
65  }
66 
67  static void TearDownTestSuite() { server_core_ = nullptr; }
68 
69  protected:
70  static std::unique_ptr<ServerCore> server_core_;
71 };
72 
73 std::unique_ptr<ServerCore> ClassificationServiceTest::server_core_;
74 
75 // Verifies that Classify() returns an error for different cases of an invalid
76 // ClassificationRequest.model_spec.
77 TEST_F(ClassificationServiceTest, InvalidModelSpec) {
78  ClassificationRequest request;
79  ClassificationResponse response;
80 
81  // No model_spec specified.
82  EXPECT_EQ(TensorflowClassificationServiceImpl::Classify(
83  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
84  request, &response)
85  .code(),
86  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
87 
88  // No model name specified.
89  auto* model_spec = request.mutable_model_spec();
90  EXPECT_EQ(TensorflowClassificationServiceImpl::Classify(
91  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
92  request, &response)
93  .code(),
94  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
95 
96  // No servable found for model name "foo".
97  model_spec->set_name("foo");
98  EXPECT_EQ(TensorflowClassificationServiceImpl::Classify(
99  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
100  request, &response)
101  .code(),
102  tensorflow::error::NOT_FOUND);
103 }
104 
105 // Verifies that Classify() returns an error for an invalid signature_name in
106 // ClassificationRequests's model_spec.
107 TEST_F(ClassificationServiceTest, InvalidSignature) {
108  auto request = test_util::CreateProto<ClassificationRequest>(
109  "model_spec {"
110  " name: \"test_model\""
111  " signature_name: \"invalid_signature_name\""
112  "}");
113  ClassificationResponse response;
114  EXPECT_EQ(TensorflowClassificationServiceImpl::Classify(
115  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
116  request, &response)
117  .code(),
118  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
119 }
120 
// Verifies that Classify() returns the correct score for a valid
// ClassificationRequest against the half_plus_two SavedModel's classify_x_to_y
// signature.
TEST_F(ClassificationServiceTest, ClassificationSuccess) {
  // Textproto request: a single Example with feature "x" = 80.0, plus two
  // extra features ("locale", "age") that the classify_x_to_y signature is
  // expected to ignore.
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      " name: \"test_model\""
      " signature_name: \"classify_x_to_y\""
      "}"
      "input {"
      " example_list {"
      " examples {"
      " features {"
      " feature: {"
      " key : \"x\""
      " value: {"
      " float_list: {"
      " value: [ 80.0 ]"
      " }"
      " }"
      " }"
      " feature: {"
      " key : \"locale\""
      " value: {"
      " bytes_list: {"
      " value: [ \"pt_BR\" ]"
      " }"
      " }"
      " }"
      " feature: {"
      " key : \"age\""
      " value: {"
      " float_list: {"
      " value: [ 19.0 ]"
      " }"
      " }"
      " }"
      " }"
      " }"
      " }"
      "}");
  ClassificationResponse response;
  TF_EXPECT_OK(TensorflowClassificationServiceImpl::Classify(
      RunOptions(), server_core_.get(), thread::ThreadPoolOptions(), request,
      &response));
  // half_plus_two: y = x/2 + 2, so x = 80 -> score 42. The response echoes a
  // fully-resolved model_spec (including the served version, 123).
  EXPECT_THAT(response,
              test_util::EqualsProto(
                  "result { classifications { classes { score: 42 } } }"
                  "model_spec {"
                  " name: \"test_model\""
                  " signature_name: \"classify_x_to_y\""
                  " version { value: 123 }"
                  "}"));
}
175 
176 // Verifies that ClassifyWithModelSpec() uses the model spec override rather
177 // than the one in the request.
178 TEST_F(ClassificationServiceTest, ModelSpecOverride) {
179  auto request = test_util::CreateProto<ClassificationRequest>(
180  "model_spec {"
181  " name: \"test_model\""
182  "}");
183  auto model_spec_override =
184  test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");
185 
186  ClassificationResponse response;
187  EXPECT_NE(tensorflow::error::NOT_FOUND,
188  TensorflowClassificationServiceImpl::Classify(
189  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
190  request, &response)
191  .code());
192  EXPECT_EQ(tensorflow::error::NOT_FOUND,
193  TensorflowClassificationServiceImpl::ClassifyWithModelSpec(
194  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
195  model_spec_override, request, &response)
196  .code());
197 }
198 
// Verifies that Classify() actually runs the session through caller-supplied
// thread pools when thread::ThreadPoolOptions is populated.
TEST_F(ClassificationServiceTest, ThreadPoolOptions) {
  // Same request as ClassificationSuccess: one Example with "x" = 80.0 plus
  // two features the signature should ignore.
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      " name: \"test_model\""
      " signature_name: \"classify_x_to_y\""
      "}"
      "input {"
      " example_list {"
      " examples {"
      " features {"
      " feature: {"
      " key : \"x\""
      " value: {"
      " float_list: {"
      " value: [ 80.0 ]"
      " }"
      " }"
      " }"
      " feature: {"
      " key : \"locale\""
      " value: {"
      " bytes_list: {"
      " value: [ \"pt_BR\" ]"
      " }"
      " }"
      " }"
      " feature: {"
      " key : \"age\""
      " value: {"
      " float_list: {"
      " value: [ 19.0 ]"
      " }"
      " }"
      " }"
      " }"
      " }"
      " }"
      "}");

  // Counting pools record how many closures were scheduled, letting the test
  // observe whether the session used them.
  test_util::CountingThreadPool inter_op_threadpool(Env::Default(), "InterOp",
                                                    /*num_threads=*/1);
  test_util::CountingThreadPool intra_op_threadpool(Env::Default(), "IntraOp",
                                                    /*num_threads=*/1);
  thread::ThreadPoolOptions thread_pool_options;
  thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
  thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
  ClassificationResponse response;
  TF_EXPECT_OK(TensorflowClassificationServiceImpl::Classify(
      RunOptions(), server_core_.get(), thread_pool_options, request,
      &response));
  // Same expected result as ClassificationSuccess: y = x/2 + 2 = 42.
  EXPECT_THAT(response,
              test_util::EqualsProto(
                  "result { classifications { classes { score: 42 } } }"
                  "model_spec {"
                  " name: \"test_model\""
                  " signature_name: \"classify_x_to_y\""
                  " version { value: 123 }"
                  "}"));

  // The intra_op_threadpool doesn't have anything scheduled.
  // Only the inter-op pool is asserted on: at least one op must have been
  // scheduled there, proving the supplied pool was used.
  ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
}
261 
262 } // namespace
263 } // namespace serving
264 } // namespace tensorflow
// static Status Create(Options options, std::unique_ptr<ServerCore>* core)
// Definition: server_core.cc:231