TensorFlow Serving C++ API Documentation
tfrt_multi_inference_test.cc
/* Copyright 2020 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_serving/servables/tensorflow/tfrt_multi_inference.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/apis/input.pb.h"
#include "tensorflow_serving/apis/regression.pb.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/core/servable_handle.h"
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/model_servers/platform_config_util.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
#include "tensorflow_serving/test_util/test_util.h"

namespace tensorflow {
namespace serving {
namespace {

constexpr char kTestModelName[] = "test_model";
constexpr int kTestModelVersion = 123;

class TfrtMultiInferenceTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() {
    tfrt_stub::SetGlobalRuntime(
        tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4));
    CreateServerCore(&server_core_);
  }

  static void TearDownTestSuite() { server_core_.reset(); }

 protected:
  static void CreateServerCore(std::unique_ptr<ServerCore>* server_core) {
    ModelServerConfig config;
    auto model_config = config.mutable_model_config_list()->add_config();
    model_config->set_name(kTestModelName);
    model_config->set_base_path(
        test_util::TestSrcDirPath("servables/tensorflow/"
                                  "testdata/saved_model_half_plus_two_cpu"));
    model_config->set_model_platform(kTensorFlowModelPlatform);

    // For ServerCore Options, we leave servable_state_monitor_creator
    // unspecified so the default servable_state_monitor_creator will be used.
    ServerCore::Options options;
    options.model_server_config = config;
    PlatformConfigMap platform_config_map;
    ::google::protobuf::Any source_adapter_config;
    TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
    source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
    (*(*platform_config_map
            .mutable_platform_configs())[kTensorFlowModelPlatform]
          .mutable_source_adapter_config()) = source_adapter_config;
    options.platform_config_map = platform_config_map;
    options.aspired_version_policy =
        std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
    // Reduce the number of initial load threads to num_load_threads to avoid
    // timing out in tests.
    options.num_initial_load_threads = options.num_load_threads;
    TF_ASSERT_OK(ServerCore::Create(std::move(options), server_core));
  }

  ServerCore* GetServerCore() { return server_core_.get(); }

  Status GetServableHandle(ServableHandle<Servable>* servable) {
    ModelSpec model_spec;
    model_spec.set_name(kTestModelName);
    return GetServerCore()->GetServableHandle(model_spec, servable);
  }

  const int64_t servable_version_ = kTestModelVersion;

 private:
  static std::unique_ptr<ServerCore> server_core_;
};

std::unique_ptr<ServerCore> TfrtMultiInferenceTest::server_core_;

// Test Helpers

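// Appends a tf.Example whose float features are given by `feature_kv` to the
// request's example_list input.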
void AddInput(const std::vector<std::pair<string, float>>& feature_kv,
              MultiInferenceRequest* request) {
  auto* example =
      request->mutable_input()->mutable_example_list()->add_examples();
  auto* features = example->mutable_features()->mutable_feature();
  for (const auto& feature : feature_kv) {
    (*features)[feature.first].mutable_float_list()->add_value(feature.second);
  }
}

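// Fills `task` with a ModelSpec for kTestModelName (setting the version only
// when `version` is positive), the given signature name, and the given method
// name.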
void PopulateTask(const string& signature_name, const string& method_name,
                  int64_t version, InferenceTask* task) {
  ModelSpec model_spec;
  model_spec.set_name(kTestModelName);
  if (version > 0) {
    model_spec.mutable_version()->set_value(version);
  }
  model_spec.set_signature_name(signature_name);
  *task->mutable_model_spec() = model_spec;
  task->set_method_name(method_name);
}

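// Asserts that `status` carries `expected_code` and that its error message
// contains `message_substring`.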
void ExpectStatusError(const Status& status,
                       const absl::StatusCode expected_code,
                       const string& message_substring) {
  ASSERT_EQ(expected_code, status.code());
  EXPECT_THAT(status.message(), ::testing::HasSubstr(message_substring));
}

// Tests

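// A request that contains a valid task but no input is rejected with
// INVALID_ARGUMENT.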
TEST_F(TfrtMultiInferenceTest, MissingInputTest) {
  MultiInferenceRequest request;
  PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());

  MultiInferenceResponse response;

  ServableHandle<Servable> servable;
  TF_ASSERT_OK(GetServableHandle(&servable));
  ExpectStatusError(
      RunMultiInference(
          tfrt::SavedModel::RunOptions(), servable_version_,
          &(down_cast<TfrtSavedModelServable*>(servable.get()))  // NOLINT
               ->saved_model(),
          request, &response),
      absl::StatusCode::kInvalidArgument, "Input is empty");
}

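// Referencing a signature that the SavedModel does not define fails with
// INVALID_ARGUMENT.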
TEST_F(TfrtMultiInferenceTest, UndefinedSignatureTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("ThisSignatureDoesNotExist", kRegressMethodName, -1,
               request.add_tasks());

  MultiInferenceResponse response;
  ServableHandle<Servable> servable;
  TF_ASSERT_OK(GetServableHandle(&servable));
  ExpectStatusError(servable->MultiInference({}, request, &response),
                    absl::StatusCode::kInvalidArgument, "not found");
}

// Two ModelSpecs, accessing different models.
TEST_F(TfrtMultiInferenceTest, InconsistentModelSpecsInRequestTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  // Valid signature.
  PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());

  // Add invalid Task to request.
  ModelSpec model_spec;
  model_spec.set_name("ModelDoesNotExist");
  model_spec.set_signature_name("regress_x_to_y");
  auto* task = request.add_tasks();
  *task->mutable_model_spec() = model_spec;
  task->set_method_name(kRegressMethodName);

  MultiInferenceResponse response;
  ServableHandle<Servable> servable;
  TF_ASSERT_OK(GetServableHandle(&servable));
  ExpectStatusError(servable->MultiInference({}, request, &response),
                    absl::StatusCode::kInvalidArgument,
                    "must access the same model name");
}

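// Requesting the same signature twice in a single request is rejected with
// INVALID_ARGUMENT.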
TEST_F(TfrtMultiInferenceTest, EvaluateDuplicateFunctionsTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
  // Add the same task again (error).
  PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());

  MultiInferenceResponse response;
  ServableHandle<Servable> servable;
  TF_ASSERT_OK(GetServableHandle(&servable));
  ExpectStatusError(servable->MultiInference({}, request, &response),
                    absl::StatusCode::kInvalidArgument,
                    "Duplicate evaluation of signature: regress_x_to_y");
}

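// A Predict signature is not supported by MultiInference and yields
// UNIMPLEMENTED.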
TEST_F(TfrtMultiInferenceTest, UnsupportedSignatureTypeTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("serving_default", kPredictMethodName, -1, request.add_tasks());

  MultiInferenceResponse response;
  ServableHandle<Servable> servable;
  TF_ASSERT_OK(GetServableHandle(&servable));
  ExpectStatusError(servable->MultiInference({}, request, &response),
                    absl::StatusCode::kUnimplemented, "Unsupported signature");
}

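// A single regression task succeeds: regress_x_to_y computes y = 0.5x + 2,
// so x = 2 yields 3.0, and the response echoes the task's ModelSpec with the
// served version filled in.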
TEST_F(TfrtMultiInferenceTest, ValidSingleSignatureTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, servable_version_,
               request.add_tasks());

  MultiInferenceResponse expected_response;
  auto* inference_result = expected_response.add_results();
  auto* model_spec = inference_result->mutable_model_spec();
  *model_spec = request.tasks(0).model_spec();
  model_spec->mutable_version()->set_value(servable_version_);
  auto* regression_result = inference_result->mutable_regression_result();
  regression_result->add_regressions()->set_value(3.0);

  MultiInferenceResponse response;
  ServableHandle<Servable> servable;
  TF_ASSERT_OK(GetServableHandle(&servable));
  TF_ASSERT_OK(servable->MultiInference({}, request, &response));
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}

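// Two regression signatures are evaluated in one request, each producing its
// own InferenceResult.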
TEST_F(TfrtMultiInferenceTest, MultipleValidRegressSignaturesTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, servable_version_,
               request.add_tasks());
  PopulateTask("regress_x_to_y2", kRegressMethodName, servable_version_,
               request.add_tasks());

  MultiInferenceResponse expected_response;

  // regress_x_to_y is y = 0.5x + 2.
  auto* inference_result_1 = expected_response.add_results();
  auto* model_spec_1 = inference_result_1->mutable_model_spec();
  *model_spec_1 = request.tasks(0).model_spec();
  model_spec_1->mutable_version()->set_value(servable_version_);
  auto* regression_result_1 = inference_result_1->mutable_regression_result();
  regression_result_1->add_regressions()->set_value(3.0);

  // regress_x_to_y2 is y2 = 0.5x + 3.
  auto* inference_result_2 = expected_response.add_results();
  auto* model_spec_2 = inference_result_2->mutable_model_spec();
  *model_spec_2 = request.tasks(1).model_spec();
  model_spec_2->mutable_version()->set_value(servable_version_);
  auto* regression_result_2 = inference_result_2->mutable_regression_result();
  regression_result_2->add_regressions()->set_value(4.0);

  MultiInferenceResponse response;
  ServableHandle<Servable> servable;
  TF_ASSERT_OK(GetServableHandle(&servable));
  TF_ASSERT_OK(servable->MultiInference({}, request, &response));
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}

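// A regression task and a classification task can be combined in one request.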
TEST_F(TfrtMultiInferenceTest, RegressAndClassifySignaturesTest) {
  MultiInferenceRequest request;
  AddInput({{"x", 2}}, &request);
  PopulateTask("regress_x_to_y", kRegressMethodName, servable_version_,
               request.add_tasks());
  PopulateTask("classify_x_to_y", kClassifyMethodName, servable_version_,
               request.add_tasks());

  MultiInferenceResponse expected_response;
  auto* inference_result_1 = expected_response.add_results();
  auto* model_spec_1 = inference_result_1->mutable_model_spec();
  *model_spec_1 = request.tasks(0).model_spec();
  model_spec_1->mutable_version()->set_value(servable_version_);
  auto* regression_result = inference_result_1->mutable_regression_result();
  regression_result->add_regressions()->set_value(3.0);

  auto* inference_result_2 = expected_response.add_results();
  auto* model_spec_2 = inference_result_2->mutable_model_spec();
  *model_spec_2 = request.tasks(1).model_spec();
  model_spec_2->mutable_version()->set_value(servable_version_);
  auto* classification_result =
      inference_result_2->mutable_classification_result();
  classification_result->add_classifications()->add_classes()->set_score(3.0);

  MultiInferenceResponse response;
  ServableHandle<Servable> servable;
  TF_ASSERT_OK(GetServableHandle(&servable));
  TF_ASSERT_OK(servable->MultiInference({}, request, &response));
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}

}  // namespace
}  // namespace serving
}  // namespace tensorflow