TensorFlow Serving C++ API Documentation
tfrt_regression_service_test.cc
1 /* Copyright 2018 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/tfrt_regression_service.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/protobuf/config.pb.h"
24 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
25 #include "tensorflow_serving/config/model_server_config.pb.h"
26 #include "tensorflow_serving/core/availability_preserving_policy.h"
27 #include "tensorflow_serving/model_servers/model_platform_types.h"
28 #include "tensorflow_serving/model_servers/platform_config_util.h"
29 #include "tensorflow_serving/model_servers/server_core.h"
30 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
31 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
32 #include "tensorflow_serving/test_util/test_util.h"
33 
34 namespace tensorflow {
35 namespace serving {
36 namespace {
37 
38 constexpr char kTestModelName[] = "test_model";
39 
40 // Test fixture for TFRTRegressionService related tests sets up a ServerCore
41 // pointing to the half_plus_two SavedModel.
42 class TFRTRegressionServiceTest : public ::testing::Test {
43  public:
44  static void SetUpTestSuite() {
45  tfrt_stub::SetGlobalRuntime(
46  tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4));
47 
48  ModelServerConfig config;
49  auto model_config = config.mutable_model_config_list()->add_config();
50  model_config->set_name(kTestModelName);
51  model_config->set_base_path(
52  test_util::TestSrcDirPath("servables/tensorflow/"
53  "testdata/saved_model_half_plus_two_cpu"));
54  model_config->set_model_platform(kTensorFlowModelPlatform);
55 
56  // For ServerCore Options, we leave servable_state_monitor_creator
57  // unspecified so the default servable_state_monitor_creator will be used.
58  ServerCore::Options options;
59  options.model_server_config = config;
60  PlatformConfigMap platform_config_map;
61  ::google::protobuf::Any source_adapter_config;
62  TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
63  source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
64  (*(*platform_config_map
65  .mutable_platform_configs())[kTensorFlowModelPlatform]
66  .mutable_source_adapter_config()) = source_adapter_config;
67  options.platform_config_map = platform_config_map;
68  options.aspired_version_policy =
69  std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
70  // Reduce the number of initial load threads to be num_load_threads to avoid
71  // timing out in tests.
72  options.num_initial_load_threads = options.num_load_threads;
73  TF_ASSERT_OK(ServerCore::Create(std::move(options), &server_core_));
74  }
75 
76  static void TearDownTestSuite() { server_core_ = nullptr; }
77 
78  protected:
79  static std::unique_ptr<ServerCore> server_core_;
80  Servable::RunOptions run_options_;
81 };
82 
83 std::unique_ptr<ServerCore> TFRTRegressionServiceTest::server_core_;
84 
85 // Verifies that Regress() returns an error for different cases of an invalid
86 // RegressionRequest.model_spec.
87 TEST_F(TFRTRegressionServiceTest, InvalidModelSpec) {
88  RegressionRequest request;
89  RegressionResponse response;
90 
91  // No model_spec specified.
92  EXPECT_EQ(TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
93  request, &response)
94  .code(),
95  absl::StatusCode::kInvalidArgument);
96 
97  // No model name specified.
98  auto* model_spec = request.mutable_model_spec();
99  EXPECT_EQ(TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
100  request, &response)
101  .code(),
102  absl::StatusCode::kInvalidArgument);
103 
104  // No servable found for model name "foo".
105  model_spec->set_name("foo");
106  EXPECT_EQ(TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
107  request, &response)
108  .code(),
109  tensorflow::error::NOT_FOUND);
110 }
111 
112 // Verifies that Regress() returns an error for an invalid signature_name in
113 // RegressionRequests's model_spec.
114 TEST_F(TFRTRegressionServiceTest, InvalidSignature) {
115  auto request = test_util::CreateProto<RegressionRequest>(
116  "model_spec {"
117  " name: \"test_model\""
118  " signature_name: \"invalid_signature_name\""
119  "}");
120  RegressionResponse response;
121  EXPECT_EQ(TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
122  request, &response)
123  .code(),
124  tensorflow::error::FAILED_PRECONDITION);
125 }
126 
127 // Verifies that Regress() returns the correct value for a valid
128 // RegressionRequest against the half_plus_two SavedModel's regress_x_to_y
129 // signature.
130 TEST_F(TFRTRegressionServiceTest, RegressionSuccess) {
131  auto request = test_util::CreateProto<RegressionRequest>(
132  "model_spec {"
133  " name: \"test_model\""
134  " signature_name: \"regress_x_to_y\""
135  "}"
136  "input {"
137  " example_list {"
138  " examples {"
139  " features {"
140  " feature: {"
141  " key : \"x\""
142  " value: {"
143  " float_list: {"
144  " value: [ 80.0 ]"
145  " }"
146  " }"
147  " }"
148  " feature: {"
149  " key : \"locale\""
150  " value: {"
151  " bytes_list: {"
152  " value: [ \"pt_BR\" ]"
153  " }"
154  " }"
155  " }"
156  " feature: {"
157  " key : \"age\""
158  " value: {"
159  " float_list: {"
160  " value: [ 19.0 ]"
161  " }"
162  " }"
163  " }"
164  " }"
165  " }"
166  " }"
167  "}");
168  RegressionResponse response;
169  TF_EXPECT_OK(TFRTRegressionServiceImpl::Regress(
170  run_options_, server_core_.get(), request, &response));
171  EXPECT_THAT(response,
172  test_util::EqualsProto("result { regressions { value: 42 } }"
173  "model_spec {"
174  " name: \"test_model\""
175  " signature_name: \"regress_x_to_y\""
176  " version { value: 123 }"
177  "}"));
178 }
179 
180 // Verifies that RegressWithModelSpec() uses the model spec override rather than
181 // the one in the request.
182 TEST_F(TFRTRegressionServiceTest, ModelSpecOverride) {
183  auto request = test_util::CreateProto<RegressionRequest>(
184  "model_spec {"
185  " name: \"test_model\""
186  "}");
187  auto model_spec_override =
188  test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");
189 
190  RegressionResponse response;
191  EXPECT_NE(tensorflow::error::NOT_FOUND,
192  TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
193  request, &response)
194  .code());
195  EXPECT_EQ(tensorflow::error::NOT_FOUND,
196  TFRTRegressionServiceImpl::RegressWithModelSpec(
197  run_options_, server_core_.get(), model_spec_override, request,
198  &response)
199  .code());
200 }
201 
202 } // namespace
203 } // namespace serving
204 } // namespace tensorflow
static Status Create(Options options, std::unique_ptr< ServerCore > *core)
Definition: server_core.cc:231