TensorFlow Serving C++ API Documentation
regression_service_test.cc
1 /* Copyright 2018 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/regression_service.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/protobuf/config.pb.h"
26 #include "tensorflow_serving/config/model_server_config.pb.h"
27 #include "tensorflow_serving/core/availability_preserving_policy.h"
28 #include "tensorflow_serving/model_servers/model_platform_types.h"
29 #include "tensorflow_serving/model_servers/platform_config_util.h"
30 #include "tensorflow_serving/model_servers/server_core.h"
31 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
32 #include "tensorflow_serving/test_util/test_util.h"
33 
34 namespace tensorflow {
35 namespace serving {
36 namespace {
37 
38 constexpr char kTestModelName[] = "test_model";
39 
40 // Test fixture for RegressionService related tests sets up a ServerCore
41 // pointing to the half_plus_two SavedModel.
42 class RegressionServiceTest : public ::testing::Test {
43  public:
44  static void SetUpTestSuite() {
45  ModelServerConfig config;
46  auto model_config = config.mutable_model_config_list()->add_config();
47  model_config->set_name(kTestModelName);
48  model_config->set_base_path(test_util::TensorflowTestSrcDirPath(
49  "cc/saved_model/testdata/half_plus_two"));
50  model_config->set_model_platform(kTensorFlowModelPlatform);
51 
52  // For ServerCore Options, we leave servable_state_monitor_creator
53  // unspecified so the default servable_state_monitor_creator will be used.
54  ServerCore::Options options;
55  options.model_server_config = config;
56  options.platform_config_map =
57  CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
58  options.aspired_version_policy =
59  std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
60  // Reduce the number of initial load threads to be num_load_threads to avoid
61  // timing out in tests.
62  options.num_initial_load_threads = options.num_load_threads;
63  TF_ASSERT_OK(ServerCore::Create(std::move(options), &server_core_));
64  }
65 
66  static void TearDownTestSuite() { server_core_ = nullptr; }
67 
68  protected:
69  static std::unique_ptr<ServerCore> server_core_;
70 };
71 
72 std::unique_ptr<ServerCore> RegressionServiceTest::server_core_;
73 
74 // Verifies that Regress() returns an error for different cases of an invalid
75 // RegressionRequest.model_spec.
76 TEST_F(RegressionServiceTest, InvalidModelSpec) {
77  RegressionRequest request;
78  RegressionResponse response;
79 
80  // No model_spec specified.
81  EXPECT_EQ(TensorflowRegressionServiceImpl::Regress(
82  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
83  request, &response)
84  .code(),
85  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
86 
87  // No model name specified.
88  auto* model_spec = request.mutable_model_spec();
89  EXPECT_EQ(TensorflowRegressionServiceImpl::Regress(
90  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
91  request, &response)
92  .code(),
93  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
94 
95  // No servable found for model name "foo".
96  model_spec->set_name("foo");
97  EXPECT_EQ(TensorflowRegressionServiceImpl::Regress(
98  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
99  request, &response)
100  .code(),
101  tensorflow::error::NOT_FOUND);
102 }
103 
104 // Verifies that Regress() returns an error for an invalid signature_name in
105 // RegressionRequests's model_spec.
106 TEST_F(RegressionServiceTest, InvalidSignature) {
107  auto request = test_util::CreateProto<RegressionRequest>(
108  "model_spec {"
109  " name: \"test_model\""
110  " signature_name: \"invalid_signature_name\""
111  "}");
112  RegressionResponse response;
113  EXPECT_EQ(TensorflowRegressionServiceImpl::Regress(
114  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
115  request, &response)
116  .code(),
117  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument));
118 }
119 
120 // Verifies that Regress() returns the correct value for a valid
121 // RegressionRequest against the half_plus_two SavedModel's regress_x_to_y
122 // signature.
123 TEST_F(RegressionServiceTest, RegressionSuccess) {
124  auto request = test_util::CreateProto<RegressionRequest>(
125  "model_spec {"
126  " name: \"test_model\""
127  " signature_name: \"regress_x_to_y\""
128  "}"
129  "input {"
130  " example_list {"
131  " examples {"
132  " features {"
133  " feature: {"
134  " key : \"x\""
135  " value: {"
136  " float_list: {"
137  " value: [ 80.0 ]"
138  " }"
139  " }"
140  " }"
141  " feature: {"
142  " key : \"locale\""
143  " value: {"
144  " bytes_list: {"
145  " value: [ \"pt_BR\" ]"
146  " }"
147  " }"
148  " }"
149  " feature: {"
150  " key : \"age\""
151  " value: {"
152  " float_list: {"
153  " value: [ 19.0 ]"
154  " }"
155  " }"
156  " }"
157  " }"
158  " }"
159  " }"
160  "}");
161  RegressionResponse response;
162  TF_EXPECT_OK(TensorflowRegressionServiceImpl::Regress(
163  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(), request,
164  &response));
165  EXPECT_THAT(response,
166  test_util::EqualsProto("result { regressions { value: 42 } }"
167  "model_spec {"
168  " name: \"test_model\""
169  " signature_name: \"regress_x_to_y\""
170  " version { value: 123 }"
171  "}"));
172 }
173 
174 // Verifies that RegressWithModelSpec() uses the model spec override rather than
175 // the one in the request.
176 TEST_F(RegressionServiceTest, ModelSpecOverride) {
177  auto request = test_util::CreateProto<RegressionRequest>(
178  "model_spec {"
179  " name: \"test_model\""
180  "}");
181  auto model_spec_override =
182  test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");
183 
184  RegressionResponse response;
185  EXPECT_NE(tensorflow::error::NOT_FOUND,
186  TensorflowRegressionServiceImpl::Regress(
187  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
188  request, &response)
189  .code());
190  EXPECT_EQ(tensorflow::error::NOT_FOUND,
191  TensorflowRegressionServiceImpl::RegressWithModelSpec(
192  RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
193  model_spec_override, request, &response)
194  .code());
195 }
196 
197 TEST_F(RegressionServiceTest, ThreadPoolOptions) {
198  auto request = test_util::CreateProto<RegressionRequest>(
199  "model_spec {"
200  " name: \"test_model\""
201  " signature_name: \"regress_x_to_y\""
202  "}"
203  "input {"
204  " example_list {"
205  " examples {"
206  " features {"
207  " feature: {"
208  " key : \"x\""
209  " value: {"
210  " float_list: {"
211  " value: [ 80.0 ]"
212  " }"
213  " }"
214  " }"
215  " feature: {"
216  " key : \"locale\""
217  " value: {"
218  " bytes_list: {"
219  " value: [ \"pt_BR\" ]"
220  " }"
221  " }"
222  " }"
223  " feature: {"
224  " key : \"age\""
225  " value: {"
226  " float_list: {"
227  " value: [ 19.0 ]"
228  " }"
229  " }"
230  " }"
231  " }"
232  " }"
233  " }"
234  "}");
235 
236  test_util::CountingThreadPool inter_op_threadpool(Env::Default(), "InterOp",
237  /*num_threads=*/1);
238  test_util::CountingThreadPool intra_op_threadpool(Env::Default(), "IntraOp",
239  /*num_threads=*/1);
240  thread::ThreadPoolOptions thread_pool_options;
241  thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
242  thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
243  RegressionResponse response;
244  TF_EXPECT_OK(TensorflowRegressionServiceImpl::Regress(
245  RunOptions(), server_core_.get(), thread_pool_options, request,
246  &response));
247  EXPECT_THAT(response,
248  test_util::EqualsProto("result { regressions { value: 42 } }"
249  "model_spec {"
250  " name: \"test_model\""
251  " signature_name: \"regress_x_to_y\""
252  " version { value: 123 }"
253  "}"));
254  // The intra_op_threadpool doesn't have anything scheduled.
255  ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
256 }
257 
258 } // namespace
259 } // namespace serving
260 } // namespace tensorflow
static Status Create(Options options, std::unique_ptr< ServerCore > *core)
Definition: server_core.cc:231