16 #include "tensorflow_serving/servables/tensorflow/multi_inference_helper.h"
19 #include <type_traits>
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include "tensorflow/cc/saved_model/loader.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "tensorflow/core/example/example.pb.h"
28 #include "tensorflow/core/example/feature.pb.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow_serving/apis/classification.pb.h"
31 #include "tensorflow_serving/apis/input.pb.h"
32 #include "tensorflow_serving/apis/regression.pb.h"
33 #include "tensorflow_serving/core/availability_preserving_policy.h"
34 #include "tensorflow_serving/model_servers/model_platform_types.h"
35 #include "tensorflow_serving/model_servers/platform_config_util.h"
36 #include "tensorflow_serving/model_servers/server_core.h"
37 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
38 #include "tensorflow_serving/servables/tensorflow/util.h"
39 #include "tensorflow_serving/test_util/test_util.h"
41 namespace tensorflow {
// Model name placed in every ModelSpec built by PopulateTask.
constexpr char kTestModelName[]{"test_model"};
// Servable version the tests pin in their ModelSpecs and expect echoed back
// in the response.
constexpr int kTestModelVersion{123};
// Type tags that parameterize the typed test suite: tf1_model_t selects the
// TF1 SavedModel, tf2_model_t the TF2 one.
using tf1_model_t = std::integral_constant<int, 1>;
using tf2_model_t = std::integral_constant<int, 2>;
54 class MultiInferenceTest :
public ::testing::Test {
56 static void SetUpTestSuite() {
57 SetSignatureMethodNameCheckFeature(UseTf1Model());
58 TF_ASSERT_OK(CreateServerCore(&server_core_));
61 static void TearDownTestSuite() { server_core_.reset(); }
64 static Status CreateServerCore(std::unique_ptr<ServerCore>* server_core) {
65 ModelServerConfig config;
66 auto model_config = config.mutable_model_config_list()->add_config();
67 model_config->set_name(kTestModelName);
68 const auto& tf1_saved_model = test_util::TensorflowTestSrcDirPath(
69 "cc/saved_model/testdata/half_plus_two");
70 const auto& tf2_saved_model = test_util::TestSrcDirPath(
71 "/servables/tensorflow/testdata/saved_model_half_plus_two_tf2_cpu");
72 model_config->set_base_path(UseTf1Model() ? tf1_saved_model
74 model_config->set_model_platform(kTensorFlowModelPlatform);
78 ServerCore::Options options;
79 options.model_server_config = config;
80 options.platform_config_map =
81 CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
84 options.num_initial_load_threads = options.num_load_threads;
85 options.aspired_version_policy =
86 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
90 static bool UseTf1Model() {
return std::is_same<T, tf1_model_t>::value; }
92 ServerCore* GetServerCore() {
return this->server_core_.get(); }
94 const int64_t servable_version_ = kTestModelVersion;
97 static std::unique_ptr<ServerCore> server_core_;
100 template <
typename T>
101 std::unique_ptr<ServerCore> MultiInferenceTest<T>::server_core_;
103 TYPED_TEST_SUITE_P(MultiInferenceTest);
108 void AddInput(
const std::vector<std::pair<string, float>>& feature_kv,
109 MultiInferenceRequest* request) {
111 request->mutable_input()->mutable_example_list()->add_examples();
112 auto* features = example->mutable_features()->mutable_feature();
113 for (
const auto& feature : feature_kv) {
114 (*features)[feature.first].mutable_float_list()->add_value(feature.second);
118 void PopulateTask(
const string& signature_name,
const string& method_name,
119 int64_t version, InferenceTask* task) {
120 ModelSpec model_spec;
121 model_spec.set_name(kTestModelName);
123 model_spec.mutable_version()->set_value(version);
125 model_spec.set_signature_name(signature_name);
126 *task->mutable_model_spec() = model_spec;
127 task->set_method_name(method_name);
130 void ExpectStatusError(
const Status& status,
131 const tensorflow::errors::Code expected_code,
132 const string& message_substring) {
133 ASSERT_EQ(expected_code, status.code());
134 EXPECT_THAT(status.message(), ::testing::HasSubstr(message_substring));
// A request that has a task but no input examples (AddInput is never called)
// is expected to fail with kInvalidArgument.
// NOTE(review): this copy appears to have lost lines — the ExpectStatusError(
// call head, the `&response` argument, the expected-message argument and the
// closing brace. Restore them from upstream before building.
140 TYPED_TEST_P(MultiInferenceTest, MissingInputTest) {
141 MultiInferenceRequest request;
// Version -1 is used by the tests to mean "no specific version".
142 PopulateTask(
"regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
144 MultiInferenceResponse response;
146 RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
147 thread::ThreadPoolOptions(), request,
149 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
153 TYPED_TEST_P(MultiInferenceTest, UndefinedSignatureTest) {
154 MultiInferenceRequest request;
155 AddInput({{
"x", 2}}, &request);
156 PopulateTask(
"ThisSignatureDoesNotExist", kRegressMethodName, -1,
157 request.add_tasks());
159 MultiInferenceResponse response;
161 RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
162 thread::ThreadPoolOptions(), request,
164 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
165 "signature not found");
169 TYPED_TEST_P(MultiInferenceTest, InconsistentModelSpecsInRequestTest) {
170 MultiInferenceRequest request;
171 AddInput({{
"x", 2}}, &request);
173 PopulateTask(
"regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
176 ModelSpec model_spec;
177 model_spec.set_name(
"ModelDoesNotExist");
178 model_spec.set_signature_name(
"regress_x_to_y");
179 auto* task = request.add_tasks();
180 *task->mutable_model_spec() = model_spec;
181 task->set_method_name(kRegressMethodName);
183 MultiInferenceResponse response;
185 RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
186 thread::ThreadPoolOptions(), request,
188 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
189 "must access the same model name");
192 TYPED_TEST_P(MultiInferenceTest, EvaluateDuplicateSignaturesTest) {
193 MultiInferenceRequest request;
194 AddInput({{
"x", 2}}, &request);
195 PopulateTask(
"regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
197 PopulateTask(
"regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
199 MultiInferenceResponse response;
201 RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
202 thread::ThreadPoolOptions(), request,
204 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
205 "Duplicate evaluation of signature: regress_x_to_y");
208 TYPED_TEST_P(MultiInferenceTest, UsupportedSignatureTypeTest) {
209 MultiInferenceRequest request;
210 AddInput({{
"x", 2}}, &request);
211 PopulateTask(
"serving_default", kPredictMethodName, -1, request.add_tasks());
213 MultiInferenceResponse response;
215 RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
216 thread::ThreadPoolOptions(), request,
218 static_cast<absl::StatusCode
>(absl::StatusCode::kUnimplemented),
219 "Unsupported signature");
222 TYPED_TEST_P(MultiInferenceTest, ValidSingleSignatureTest) {
223 MultiInferenceRequest request;
224 AddInput({{
"x", 2}}, &request);
225 PopulateTask(
"regress_x_to_y", kRegressMethodName, this->servable_version_,
226 request.add_tasks());
228 MultiInferenceResponse expected_response;
229 auto* inference_result = expected_response.add_results();
230 auto* model_spec = inference_result->mutable_model_spec();
231 *model_spec = request.tasks(0).model_spec();
232 model_spec->mutable_version()->set_value(this->servable_version_);
233 auto* regression_result = inference_result->mutable_regression_result();
234 regression_result->add_regressions()->set_value(3.0);
236 MultiInferenceResponse response;
237 TF_ASSERT_OK(RunMultiInferenceWithServerCore(
238 RunOptions(), this->GetServerCore(), thread::ThreadPoolOptions(), request,
240 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
243 TYPED_TEST_P(MultiInferenceTest, MultipleValidRegressSignaturesTest) {
244 MultiInferenceRequest request;
245 AddInput({{
"x", 2}}, &request);
246 PopulateTask(
"regress_x_to_y", kRegressMethodName, this->servable_version_,
247 request.add_tasks());
248 PopulateTask(
"regress_x_to_y2", kRegressMethodName, this->servable_version_,
249 request.add_tasks());
251 MultiInferenceResponse expected_response;
254 auto* inference_result_1 = expected_response.add_results();
255 auto* model_spec_1 = inference_result_1->mutable_model_spec();
256 *model_spec_1 = request.tasks(0).model_spec();
257 model_spec_1->mutable_version()->set_value(this->servable_version_);
258 auto* regression_result_1 = inference_result_1->mutable_regression_result();
259 regression_result_1->add_regressions()->set_value(3.0);
262 auto* inference_result_2 = expected_response.add_results();
263 auto* model_spec_2 = inference_result_2->mutable_model_spec();
264 *model_spec_2 = request.tasks(1).model_spec();
265 model_spec_2->mutable_version()->set_value(this->servable_version_);
266 auto* regression_result_2 = inference_result_2->mutable_regression_result();
267 regression_result_2->add_regressions()->set_value(4.0);
269 MultiInferenceResponse response;
270 TF_ASSERT_OK(RunMultiInferenceWithServerCore(
271 RunOptions(), this->GetServerCore(), thread::ThreadPoolOptions(), request,
273 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
276 TYPED_TEST_P(MultiInferenceTest, RegressAndClassifySignaturesTest) {
277 MultiInferenceRequest request;
278 AddInput({{
"x", 2}}, &request);
279 PopulateTask(
"regress_x_to_y", kRegressMethodName, this->servable_version_,
280 request.add_tasks());
281 PopulateTask(
"classify_x_to_y", kClassifyMethodName, this->servable_version_,
282 request.add_tasks());
284 MultiInferenceResponse expected_response;
285 auto* inference_result_1 = expected_response.add_results();
286 auto* model_spec_1 = inference_result_1->mutable_model_spec();
287 *model_spec_1 = request.tasks(0).model_spec();
288 model_spec_1->mutable_version()->set_value(this->servable_version_);
289 auto* regression_result = inference_result_1->mutable_regression_result();
290 regression_result->add_regressions()->set_value(3.0);
292 auto* inference_result_2 = expected_response.add_results();
293 auto* model_spec_2 = inference_result_2->mutable_model_spec();
294 *model_spec_2 = request.tasks(1).model_spec();
295 model_spec_2->mutable_version()->set_value(this->servable_version_);
296 auto* classification_result =
297 inference_result_2->mutable_classification_result();
298 classification_result->add_classifications()->add_classes()->set_score(3.0);
300 MultiInferenceResponse response;
301 TF_ASSERT_OK(RunMultiInferenceWithServerCore(
302 RunOptions(), this->GetServerCore(), thread::ThreadPoolOptions(), request,
304 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
// Exercises RunMultiInferenceWithServerCoreWithModelSpec with an override
// ModelSpec naming "nonexistent_model".
// - The plain RunMultiInferenceWithServerCore call on the same request is
//   expected NOT to report NOT_FOUND.
// - The overridden call is then compared against NOT_FOUND.
// NOTE(review): this copy appears to be missing several lines:
//   - the end of the first call (`request, &response)` and its `.code()`);
//   - the assertion macro for the second comparison;
//   - the trailing `.code()` of the second call;
//   - the closing brace.
// Confirm against upstream.
309 TYPED_TEST_P(MultiInferenceTest, ModelSpecOverride) {
310 MultiInferenceRequest request;
311 AddInput({{
"x", 2}}, &request);
312 PopulateTask(
"regress_x_to_y", kRegressMethodName, this->servable_version_,
313 request.add_tasks());
// ModelSpec parsed from text proto; only the model name is set.
314 auto model_spec_override =
315 test_util::CreateProto<ModelSpec>(
"name: \"nonexistent_model\"");
317 MultiInferenceResponse response;
318 EXPECT_NE(tensorflow::error::NOT_FOUND,
319 RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
320 thread::ThreadPoolOptions(),
324 tensorflow::error::NOT_FOUND,
325 RunMultiInferenceWithServerCoreWithModelSpec(
326 RunOptions(), this->GetServerCore(), thread::ThreadPoolOptions(),
327 model_spec_override, request, &response)
// Runs a valid single regression task with caller-supplied inter-op and
// intra-op thread pools. It checks the response and that at least one closure
// was scheduled on the inter-op pool. The intra-op pool is not asserted on.
// NOTE(review): this copy appears to be missing the final constructor argument
// of both CountingThreadPool declarations (presumably a thread count), the
// assertion wrapper opening the RunMultiInferenceWithServerCore call (the call
// ends in `));`), and the closing brace. Confirm against upstream.
331 TYPED_TEST_P(MultiInferenceTest, ThreadPoolOptions) {
332 MultiInferenceRequest request;
333 AddInput({{
"x", 2}}, &request);
334 PopulateTask(
"regress_x_to_y", kRegressMethodName, this->servable_version_,
335 request.add_tasks());
// Expected: one regression value of 3.0 for input x=2, with the served
// version filled into the result's ModelSpec.
337 MultiInferenceResponse expected_response;
338 auto* inference_result = expected_response.add_results();
339 auto* model_spec = inference_result->mutable_model_spec();
340 *model_spec = request.tasks(0).model_spec();
341 model_spec->mutable_version()->set_value(this->servable_version_);
342 auto* regression_result = inference_result->mutable_regression_result();
343 regression_result->add_regressions()->set_value(3.0);
// Counting pools record how many closures were scheduled on them.
345 test_util::CountingThreadPool inter_op_threadpool(Env::Default(),
"InterOp",
347 test_util::CountingThreadPool intra_op_threadpool(Env::Default(),
"IntraOp",
349 thread::ThreadPoolOptions thread_pool_options;
350 thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
351 thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
352 MultiInferenceResponse response;
354 RunMultiInferenceWithServerCore(RunOptions(), this->GetServerCore(),
355 thread_pool_options, request, &response));
356 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
// The caller-provided inter-op pool must have been used at least once.
359 ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
362 REGISTER_TYPED_TEST_SUITE_P(
363 MultiInferenceTest, MissingInputTest, UndefinedSignatureTest,
364 InconsistentModelSpecsInRequestTest, EvaluateDuplicateSignaturesTest,
365 UsupportedSignatureTypeTest, ValidSingleSignatureTest,
366 MultipleValidRegressSignaturesTest, RegressAndClassifySignaturesTest,
367 ModelSpecOverride, ThreadPoolOptions);
369 typedef ::testing::Types<tf1_model_t, tf2_model_t> ModelTypes;
370 INSTANTIATE_TYPED_TEST_SUITE_P(MultiInference, MultiInferenceTest, ModelTypes);
// NOTE(review): the following stray namespace-scope declaration was found here:
//   static Status Create(Options options, std::unique_ptr< ServerCore > *core)
// It appears to be the signature of ServerCore::Create. It is not valid at
// this scope, because `Options` is unqualified and unresolved here, so it is
// kept only as this comment.