#include "tensorflow_serving/servables/tensorflow/multi_inference.h"

#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/apis/input.pb.h"
#include "tensorflow_serving/apis/regression.pb.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/model_servers/platform_config_util.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/servables/tensorflow/util.h"
#include "tensorflow_serving/test_util/test_util.h"
41 namespace tensorflow {
45 constexpr
char kTestModelName[] =
"test_model";
49 typedef std::integral_constant<int, 1> tf1_model_t;
50 typedef std::integral_constant<int, 2> tf2_model_t;
53 class MultiInferenceTest :
public ::testing::Test {
55 static void SetUpTestSuite() {
56 SetSignatureMethodNameCheckFeature(UseTf1Model());
57 TF_ASSERT_OK(CreateServerCore(&server_core_));
60 static void TearDownTestSuite() { server_core_.reset(); }
63 static Status CreateServerCore(std::unique_ptr<ServerCore>* server_core) {
64 ModelServerConfig config;
65 auto model_config = config.mutable_model_config_list()->add_config();
66 model_config->set_name(kTestModelName);
67 const auto& tf1_saved_model = test_util::TensorflowTestSrcDirPath(
68 "cc/saved_model/testdata/half_plus_two");
69 const auto& tf2_saved_model = test_util::TestSrcDirPath(
70 "/servables/tensorflow/testdata/saved_model_half_plus_two_tf2_cpu");
71 model_config->set_base_path(UseTf1Model() ? tf1_saved_model
73 model_config->set_model_platform(kTensorFlowModelPlatform);
77 ServerCore::Options options;
78 options.model_server_config = config;
79 options.platform_config_map =
80 CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
83 options.num_initial_load_threads = options.num_load_threads;
84 options.aspired_version_policy =
85 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
89 static bool UseTf1Model() {
return std::is_same<T, tf1_model_t>::value; }
91 ServerCore* GetServerCore() {
return this->server_core_.get(); }
93 Status GetInferenceRunner(
94 std::unique_ptr<TensorFlowMultiInferenceRunner>* inference_runner) {
95 ServableHandle<SavedModelBundle> bundle;
97 model_spec.set_name(kTestModelName);
98 TF_RETURN_IF_ERROR(GetServerCore()->GetServableHandle(model_spec, &bundle));
100 inference_runner->reset(
new TensorFlowMultiInferenceRunner(
101 bundle->session.get(), &bundle->meta_graph_def,
102 {this->servable_version_}));
103 return absl::OkStatus();
106 Status GetServableHandle(ServableHandle<SavedModelBundle>* bundle) {
107 ModelSpec model_spec;
108 model_spec.set_name(kTestModelName);
109 return GetServerCore()->GetServableHandle(model_spec, bundle);
112 const int64_t servable_version_ = 1;
115 static std::unique_ptr<ServerCore> server_core_;
118 template <
typename T>
119 std::unique_ptr<ServerCore> MultiInferenceTest<T>::server_core_;
121 TYPED_TEST_SUITE_P(MultiInferenceTest);
126 void AddInput(
const std::vector<std::pair<string, float>>& feature_kv,
127 MultiInferenceRequest* request) {
129 request->mutable_input()->mutable_example_list()->add_examples();
130 auto* features = example->mutable_features()->mutable_feature();
131 for (
const auto& feature : feature_kv) {
132 (*features)[feature.first].mutable_float_list()->add_value(feature.second);
136 void PopulateTask(
const string& signature_name,
const string& method_name,
137 InferenceTask* task) {
138 ModelSpec model_spec;
139 model_spec.set_name(kTestModelName);
140 model_spec.set_signature_name(signature_name);
141 *task->mutable_model_spec() = model_spec;
142 task->set_method_name(method_name);
145 void ExpectStatusError(
const Status& status,
146 const tensorflow::errors::Code expected_code,
147 const string& message_substring) {
148 EXPECT_EQ(expected_code, status.code());
149 EXPECT_THAT(status.message(), ::testing::HasSubstr(message_substring));
155 TYPED_TEST_P(MultiInferenceTest, MissingInputTest) {
156 std::unique_ptr<TensorFlowMultiInferenceRunner> inference_runner;
157 TF_ASSERT_OK(this->GetInferenceRunner(&inference_runner));
159 MultiInferenceRequest request;
160 PopulateTask(
"regress_x_to_y", kRegressMethodName, request.add_tasks());
162 MultiInferenceResponse response;
164 inference_runner->Infer(RunOptions(), request, &response),
165 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
169 ServableHandle<SavedModelBundle> bundle;
170 TF_ASSERT_OK(this->GetServableHandle(&bundle));
172 RunMultiInference(RunOptions(), bundle->meta_graph_def,
173 this->servable_version_, bundle->session.get(), request,
175 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
179 TYPED_TEST_P(MultiInferenceTest, UndefinedSignatureTest) {
180 std::unique_ptr<TensorFlowMultiInferenceRunner> inference_runner;
181 TF_ASSERT_OK(this->GetInferenceRunner(&inference_runner));
183 MultiInferenceRequest request;
184 AddInput({{
"x", 2}}, &request);
185 PopulateTask(
"ThisSignatureDoesNotExist", kRegressMethodName,
186 request.add_tasks());
188 MultiInferenceResponse response;
190 inference_runner->Infer(RunOptions(), request, &response),
191 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
192 "signature not found");
195 ServableHandle<SavedModelBundle> bundle;
196 TF_ASSERT_OK(this->GetServableHandle(&bundle));
198 RunMultiInference(RunOptions(), bundle->meta_graph_def,
199 this->servable_version_, bundle->session.get(), request,
201 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
202 "signature not found");
206 TYPED_TEST_P(MultiInferenceTest, InconsistentModelSpecsInRequestTest) {
207 std::unique_ptr<TensorFlowMultiInferenceRunner> inference_runner;
208 TF_ASSERT_OK(this->GetInferenceRunner(&inference_runner));
210 MultiInferenceRequest request;
211 AddInput({{
"x", 2}}, &request);
213 PopulateTask(
"regress_x_to_y", kRegressMethodName, request.add_tasks());
216 ModelSpec model_spec;
217 model_spec.set_name(
"ModelDoesNotExist");
218 model_spec.set_signature_name(
"regress_x_to_y");
219 auto* task = request.add_tasks();
220 *task->mutable_model_spec() = model_spec;
221 task->set_method_name(kRegressMethodName);
223 MultiInferenceResponse response;
225 inference_runner->Infer(RunOptions(), request, &response),
226 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
227 "must access the same model name");
230 ServableHandle<SavedModelBundle> bundle;
231 TF_ASSERT_OK(this->GetServableHandle(&bundle));
233 RunMultiInference(RunOptions(), bundle->meta_graph_def,
234 this->servable_version_, bundle->session.get(), request,
236 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
237 "must access the same model name");
240 TYPED_TEST_P(MultiInferenceTest, EvaluateDuplicateSignaturesTest) {
241 std::unique_ptr<TensorFlowMultiInferenceRunner> inference_runner;
242 TF_ASSERT_OK(this->GetInferenceRunner(&inference_runner));
244 MultiInferenceRequest request;
245 AddInput({{
"x", 2}}, &request);
246 PopulateTask(
"regress_x_to_y", kRegressMethodName, request.add_tasks());
248 PopulateTask(
"regress_x_to_y", kRegressMethodName, request.add_tasks());
250 MultiInferenceResponse response;
252 inference_runner->Infer(RunOptions(), request, &response),
253 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
254 "Duplicate evaluation of signature: regress_x_to_y");
257 ServableHandle<SavedModelBundle> bundle;
258 TF_ASSERT_OK(this->GetServableHandle(&bundle));
260 RunMultiInference(RunOptions(), bundle->meta_graph_def,
261 this->servable_version_, bundle->session.get(), request,
263 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
264 "Duplicate evaluation of signature: regress_x_to_y");
267 TYPED_TEST_P(MultiInferenceTest, UsupportedSignatureTypeTest) {
268 std::unique_ptr<TensorFlowMultiInferenceRunner> inference_runner;
269 TF_ASSERT_OK(this->GetInferenceRunner(&inference_runner));
271 MultiInferenceRequest request;
272 AddInput({{
"x", 2}}, &request);
273 PopulateTask(
"serving_default", kPredictMethodName, request.add_tasks());
275 MultiInferenceResponse response;
277 inference_runner->Infer(RunOptions(), request, &response),
278 static_cast<absl::StatusCode
>(absl::StatusCode ::kUnimplemented),
279 "Unsupported signature");
282 ServableHandle<SavedModelBundle> bundle;
283 TF_ASSERT_OK(this->GetServableHandle(&bundle));
285 RunMultiInference(RunOptions(), bundle->meta_graph_def,
286 this->servable_version_, bundle->session.get(), request,
288 static_cast<absl::StatusCode
>(absl::StatusCode::kUnimplemented),
289 "Unsupported signature");
292 TYPED_TEST_P(MultiInferenceTest, ValidSingleSignatureTest) {
293 std::unique_ptr<TensorFlowMultiInferenceRunner> inference_runner;
294 TF_ASSERT_OK(this->GetInferenceRunner(&inference_runner));
296 MultiInferenceRequest request;
297 AddInput({{
"x", 2}}, &request);
298 PopulateTask(
"regress_x_to_y", kRegressMethodName, request.add_tasks());
300 MultiInferenceResponse expected_response;
301 auto* inference_result = expected_response.add_results();
302 auto* model_spec = inference_result->mutable_model_spec();
303 *model_spec = request.tasks(0).model_spec();
304 model_spec->mutable_version()->set_value(this->servable_version_);
305 auto* regression_result = inference_result->mutable_regression_result();
306 regression_result->add_regressions()->set_value(3.0);
308 MultiInferenceResponse response;
309 TF_ASSERT_OK(inference_runner->Infer(RunOptions(), request, &response));
310 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
314 ServableHandle<SavedModelBundle> bundle;
315 TF_ASSERT_OK(this->GetServableHandle(&bundle));
316 TF_ASSERT_OK(RunMultiInference(RunOptions(), bundle->meta_graph_def,
317 this->servable_version_, bundle->session.get(),
318 request, &response));
319 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
322 TYPED_TEST_P(MultiInferenceTest, MultipleValidRegressSignaturesTest) {
323 std::unique_ptr<TensorFlowMultiInferenceRunner> inference_runner;
324 TF_ASSERT_OK(this->GetInferenceRunner(&inference_runner));
326 MultiInferenceRequest request;
327 AddInput({{
"x", 2}}, &request);
328 PopulateTask(
"regress_x_to_y", kRegressMethodName, request.add_tasks());
329 PopulateTask(
"regress_x_to_y2", kRegressMethodName, request.add_tasks());
331 MultiInferenceResponse expected_response;
334 auto* inference_result_1 = expected_response.add_results();
335 auto* model_spec_1 = inference_result_1->mutable_model_spec();
336 *model_spec_1 = request.tasks(0).model_spec();
337 model_spec_1->mutable_version()->set_value(this->servable_version_);
338 auto* regression_result_1 = inference_result_1->mutable_regression_result();
339 regression_result_1->add_regressions()->set_value(3.0);
342 auto* inference_result_2 = expected_response.add_results();
343 auto* model_spec_2 = inference_result_2->mutable_model_spec();
344 *model_spec_2 = request.tasks(1).model_spec();
345 model_spec_2->mutable_version()->set_value(this->servable_version_);
346 auto* regression_result_2 = inference_result_2->mutable_regression_result();
347 regression_result_2->add_regressions()->set_value(4.0);
349 MultiInferenceResponse response;
350 TF_ASSERT_OK(inference_runner->Infer(RunOptions(), request, &response));
351 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
355 ServableHandle<SavedModelBundle> bundle;
356 TF_ASSERT_OK(this->GetServableHandle(&bundle));
357 TF_ASSERT_OK(RunMultiInference(RunOptions(), bundle->meta_graph_def,
358 this->servable_version_, bundle->session.get(),
359 request, &response));
360 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
363 TYPED_TEST_P(MultiInferenceTest, RegressAndClassifySignaturesTest) {
364 std::unique_ptr<TensorFlowMultiInferenceRunner> inference_runner;
365 TF_ASSERT_OK(this->GetInferenceRunner(&inference_runner));
367 MultiInferenceRequest request;
368 AddInput({{
"x", 2}}, &request);
369 PopulateTask(
"regress_x_to_y", kRegressMethodName, request.add_tasks());
370 PopulateTask(
"classify_x_to_y", kClassifyMethodName, request.add_tasks());
372 MultiInferenceResponse expected_response;
373 auto* inference_result_1 = expected_response.add_results();
374 auto* model_spec_1 = inference_result_1->mutable_model_spec();
375 *model_spec_1 = request.tasks(0).model_spec();
376 model_spec_1->mutable_version()->set_value(this->servable_version_);
377 auto* regression_result = inference_result_1->mutable_regression_result();
378 regression_result->add_regressions()->set_value(3.0);
380 auto* inference_result_2 = expected_response.add_results();
381 auto* model_spec_2 = inference_result_2->mutable_model_spec();
382 *model_spec_2 = request.tasks(1).model_spec();
383 model_spec_2->mutable_version()->set_value(this->servable_version_);
384 auto* classification_result =
385 inference_result_2->mutable_classification_result();
386 classification_result->add_classifications()->add_classes()->set_score(3.0);
388 MultiInferenceResponse response;
389 TF_ASSERT_OK(inference_runner->Infer(RunOptions(), request, &response));
390 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
394 ServableHandle<SavedModelBundle> bundle;
395 TF_ASSERT_OK(this->GetServableHandle(&bundle));
396 TF_ASSERT_OK(RunMultiInference(RunOptions(), bundle->meta_graph_def,
397 this->servable_version_, bundle->session.get(),
398 request, &response));
399 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
402 TYPED_TEST_P(MultiInferenceTest, ThreadPoolOptions) {
403 std::unique_ptr<TensorFlowMultiInferenceRunner> inference_runner;
404 TF_ASSERT_OK(this->GetInferenceRunner(&inference_runner));
406 MultiInferenceRequest request;
407 AddInput({{
"x", 2}}, &request);
408 PopulateTask(
"regress_x_to_y", kRegressMethodName, request.add_tasks());
410 MultiInferenceResponse expected_response;
411 auto* inference_result = expected_response.add_results();
412 auto* model_spec = inference_result->mutable_model_spec();
413 *model_spec = request.tasks(0).model_spec();
414 model_spec->mutable_version()->set_value(this->servable_version_);
415 auto* regression_result = inference_result->mutable_regression_result();
416 regression_result->add_regressions()->set_value(3.0);
418 test_util::CountingThreadPool inter_op_threadpool(Env::Default(),
"InterOp",
420 test_util::CountingThreadPool intra_op_threadpool(Env::Default(),
"IntraOp",
422 thread::ThreadPoolOptions thread_pool_options;
423 thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
424 thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
425 MultiInferenceResponse response;
426 ServableHandle<SavedModelBundle> bundle;
427 TF_ASSERT_OK(this->GetServableHandle(&bundle));
428 TF_ASSERT_OK(RunMultiInference(RunOptions(), bundle->meta_graph_def,
429 this->servable_version_, bundle->session.get(),
430 request, &response, thread_pool_options));
431 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
434 ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
437 REGISTER_TYPED_TEST_SUITE_P(
438 MultiInferenceTest, MissingInputTest, UndefinedSignatureTest,
439 InconsistentModelSpecsInRequestTest, EvaluateDuplicateSignaturesTest,
440 UsupportedSignatureTypeTest, ValidSingleSignatureTest,
441 MultipleValidRegressSignaturesTest, RegressAndClassifySignaturesTest,
444 typedef ::testing::Types<tf1_model_t, tf2_model_t> ModelTypes;
445 INSTANTIATE_TYPED_TEST_SUITE_P(MultiInference, MultiInferenceTest, ModelTypes);
static Status Create(Options options, std::unique_ptr< ServerCore > *core)