16 #include "tensorflow_serving/servables/tensorflow/predict_util.h"
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/cc/saved_model/loader.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/threadpool_options.h"
29 #include "tensorflow_serving/core/availability_preserving_policy.h"
30 #include "tensorflow_serving/model_servers/model_platform_types.h"
31 #include "tensorflow_serving/model_servers/platform_config_util.h"
32 #include "tensorflow_serving/model_servers/server_core.h"
33 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
34 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
35 #include "tensorflow_serving/servables/tensorflow/util.h"
36 #include "tensorflow_serving/test_util/test_util.h"
37 #include "tensorflow_serving/util/oss_or_google.h"
namespace tensorflow {
namespace serving {
namespace {
constexpr char kTestModelName[] = "test_model";
constexpr int kTestModelVersion = 123;
const char kInputTensorKey[] = "x";
const char kOutputTensorKey[] = "y";
class FakeSession : public tensorflow::Session {
 public:
  ~FakeSession() override = default;
  Status Create(const GraphDef& graph) override {
    return errors::Unimplemented("not available in fake");
  }
  Status Extend(const GraphDef& graph) override {
    return errors::Unimplemented("not available in fake");
  }
  Status Close() override {
    return errors::Unimplemented("not available in fake");
  }
  Status ListDevices(std::vector<DeviceAttributes>* response) override {
    return errors::Unimplemented("not available in fake");
  }
  Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_names,
             const std::vector<string>& target_nodes,
             std::vector<Tensor>* outputs) override {
    RunMetadata run_metadata;
    return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
               &run_metadata);
  }
  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_names,
             const std::vector<string>& target_nodes,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
    return Run(run_options, inputs, output_names, target_nodes, outputs,
               run_metadata, thread::ThreadPoolOptions());
  }
  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_names,
             const std::vector<string>& target_nodes,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata,
             const thread::ThreadPoolOptions& thread_pool_options) override {
    for (const auto& t : inputs) {
      outputs->push_back(t.second);
    }
    return absl::OkStatus();
  }
};
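
// Test fixture that brings up one ServerCore per test model (the good
// half_plus_two model, an intentionally broken one, and a stateful counter
// model) and routes Predict requests through RunPredict, the function under
// test.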
class PredictImplTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() {
    if (!IsTensorflowServingOSS()) {
      const string bad_half_plus_two_path = test_util::TestSrcDirPath(
          "/servables/tensorflow/testdata/bad_half_plus_two");
      TF_ASSERT_OK(CreateServerCore(bad_half_plus_two_path,
                                    &saved_model_server_core_bad_model_));
    }

    TF_ASSERT_OK(CreateServerCore(test_util::TensorflowTestSrcDirPath(
                                      "cc/saved_model/testdata/half_plus_two"),
                                  &saved_model_server_core_));
    TF_ASSERT_OK(CreateServerCore(
        test_util::TestSrcDirPath(
            "/servables/tensorflow/testdata/saved_model_counter"),
        &saved_model_server_core_counter_model_));
  }
  static void TearDownTestSuite() {
    saved_model_server_core_.reset();
    saved_model_server_core_bad_model_.reset();
    saved_model_server_core_counter_model_.reset();
  }

 protected:
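  // Builds a ServerCore that serves the single SavedModel at `model_path`
  // under the name kTestModelName on the TensorFlow platform.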
  static Status CreateServerCore(const string& model_path,
                                 std::unique_ptr<ServerCore>* server_core) {
    ModelServerConfig config;
    auto model_config = config.mutable_model_config_list()->add_config();
    model_config->set_name(kTestModelName);
    model_config->set_base_path(model_path);
    model_config->set_model_platform(kTensorFlowModelPlatform);

    ServerCore::Options options;
    options.model_server_config = config;
    options.platform_config_map =
        CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
    options.aspired_version_policy =
        std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
    // Use the same number of initial load threads as regular load threads to
    // avoid test timeouts.
    options.num_initial_load_threads = options.num_load_threads;
    return ServerCore::Create(std::move(options), server_core);
  }
  ServerCore* GetServerCore() { return saved_model_server_core_.get(); }

  ServerCore* GetServerCoreWithBadModel() {
    return saved_model_server_core_bad_model_.get();
  }

  ServerCore* GetServerCoreWithCounterModel() {
    return saved_model_server_core_counter_model_.get();
  }
  Status GetSavedModelServableHandle(ServerCore* server_core,
                                     ServableHandle<SavedModelBundle>* bundle) {
    ModelSpec model_spec;
    model_spec.set_name(kTestModelName);
    return server_core->GetServableHandle(model_spec, bundle);
  }
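
  // Fetches the servable from `server_core` and invokes RunPredict on it,
  // optionally with caller-supplied thread pools.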
  Status CallPredict(ServerCore* server_core, const PredictRequest& request,
                     PredictResponse* response,
                     const thread::ThreadPoolOptions& thread_pool_options =
                         thread::ThreadPoolOptions()) {
    ServableHandle<SavedModelBundle> bundle;
    TF_RETURN_IF_ERROR(GetSavedModelServableHandle(server_core, &bundle));
    return RunPredict(GetRunOptions(), bundle->meta_graph_def,
                      kTestModelVersion, bundle->session.get(), request,
                      response, thread_pool_options);
  }
  RunOptions GetRunOptions() { return RunOptions(); }
 private:
  static std::unique_ptr<ServerCore> saved_model_server_core_;
  static std::unique_ptr<ServerCore> saved_model_server_core_bad_model_;
  static std::unique_ptr<ServerCore> saved_model_server_core_counter_model_;
};
std::unique_ptr<ServerCore> PredictImplTest::saved_model_server_core_;
std::unique_ptr<ServerCore> PredictImplTest::saved_model_server_core_bad_model_;
std::unique_ptr<ServerCore>
    PredictImplTest::saved_model_server_core_counter_model_;
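
// Predict should reject requests whose ModelSpec is missing, unnamed, or
// misnamed with InvalidArgument.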
TEST_F(PredictImplTest, MissingOrEmptyModelSpec) {
  PredictRequest request;
  PredictResponse response;

  EXPECT_EQ(absl::StatusCode::kInvalidArgument,
            CallPredict(GetServerCore(), request, &response).code());

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->clear_name();

  EXPECT_EQ(absl::StatusCode::kInvalidArgument,
            CallPredict(GetServerCore(), request, &response).code());

  model_spec->set_name("test");
  EXPECT_EQ(absl::StatusCode::kInvalidArgument,
            CallPredict(GetServerCore(), request, &response).code());
}
TEST_F(PredictImplTest, EmptyInputList) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  EXPECT_EQ(absl::StatusCode::kInvalidArgument,
            CallPredict(GetServerCore(), request, &response).code());
}
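
// Input keys that do not match the signature's inputs should produce an
// InvalidArgument error naming both the extra and the missing keys.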
TEST_F(PredictImplTest, InputTensorsDontMatchModelSpecInputs) {
  PredictRequest request;
  PredictResponse response;
  auto inputs = request.mutable_inputs();

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto1;
  tensor_proto1.add_string_val("any_value");
  tensor_proto1.set_dtype(tensorflow::DT_STRING);
  tensor_proto1.mutable_tensor_shape()->add_dim()->set_size(1);
  (*inputs)["unknown_key1"] = tensor_proto1;

  TensorProto tensor_proto2;
  tensor_proto2.add_float_val(1.0);
  tensor_proto2.set_dtype(tensorflow::DT_FLOAT);
  tensor_proto2.mutable_tensor_shape()->add_dim()->set_size(1);
  (*inputs)["unknown_key2"] = tensor_proto2;

  Status status = CallPredict(GetServerCore(), request, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              ::testing::HasSubstr("Sent extra: {unknown_key1,unknown_key2}"));
  EXPECT_THAT(status.message(),
              ::testing::HasSubstr(absl::StrCat("Missing but required: {",
                                                kInputTensorKey, "}")));
}
TEST_F(PredictImplTest, PredictionInvalidTensor) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_bool_val(true);
  tensor_proto.set_dtype(tensorflow::DT_BOOL);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;

  auto status = CallPredict(GetServerCore(), request, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      status.message(),
      ::testing::HasSubstr("Expects arg[0] to be float but bool is provided"));
}
TEST_F(PredictImplTest, OutputFiltersDontMatchModelSpecOutputs) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
  request.add_output_filter("output_filter");

  // An output filter naming an alias that is absent from the signature fails.
  Status status1 = CallPredict(GetServerCore(), request, &response);
  EXPECT_EQ(status1.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status1.message(),
              ::testing::HasSubstr(
                  "output tensor alias not found in signature: output_filter"));

  // A valid output filter succeeds; adding the same alias twice fails.
  request.clear_output_filter();
  request.add_output_filter(kOutputTensorKey);
  TF_EXPECT_OK(CallPredict(GetServerCore(), request, &response));
  request.add_output_filter(kOutputTensorKey);

  Status status2 = CallPredict(GetServerCore(), request, &response);
  EXPECT_EQ(status2.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status2.message(),
              ::testing::HasSubstr("duplicate output tensor alias: y"));
}
TEST_F(PredictImplTest, InputTensorsHaveWrongType) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_string_val("any_value");
  tensor_proto.set_dtype(tensorflow::DT_STRING);
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
  request.add_output_filter(kOutputTensorKey);

  Status status = CallPredict(GetServerCore(), request, &response);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              ::testing::HasSubstr("to be float but string is provided"));
}
TEST_F(PredictImplTest, ModelMissingSignatures) {
  // The bad-model servable is only loaded in the non-OSS build (see
  // SetUpTestSuite), so skip this test in the open-source build.
  if (IsTensorflowServingOSS()) {
    return;
  }

  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  EXPECT_EQ(absl::StatusCode::kFailedPrecondition,
            CallPredict(GetServerCoreWithBadModel(), request, &response).code());
}
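
// End-to-end success case against the half_plus_two model: 2.0 * 0.5 + 2 = 3.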
TEST_F(PredictImplTest, PredictionSuccess) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;

  TF_EXPECT_OK(CallPredict(GetServerCore(), request, &response));
  TensorProto output_tensor_proto;
  output_tensor_proto.add_float_val(3);
  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto.mutable_tensor_shape();
  PredictResponse expected_response;
  *expected_response.mutable_model_spec() = *model_spec;
  expected_response.mutable_model_spec()->set_signature_name(
      kDefaultServingSignatureDefKey);
  (*expected_response.mutable_outputs())[kOutputTensorKey] =
      output_tensor_proto;
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}
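
// Predict should also work against named (non-default) regression and
// classification signatures, as exercised by the next two tests.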
TEST_F(PredictImplTest, PredictionWithNamedRegressionSignature) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  model_spec->set_signature_name("regress_x2_to_y3");

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kRegressInputs] = tensor_proto;
  TF_ASSERT_OK(CallPredict(GetServerCore(), request, &response));
  TensorProto output_tensor_proto;
  output_tensor_proto.add_float_val(4);
  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto.mutable_tensor_shape();
  PredictResponse expected_response;
  *expected_response.mutable_model_spec() = *model_spec;
  (*expected_response.mutable_outputs())[kRegressOutputs] = output_tensor_proto;
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}
TEST_F(PredictImplTest, PredictionWithNamedClassificationSignature) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  model_spec->set_signature_name("classify_x2_to_y3");

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kClassifyInputs] = tensor_proto;

  TF_ASSERT_OK(CallPredict(GetServerCore(), request, &response));
  TensorProto output_tensor_proto;
  output_tensor_proto.add_float_val(4);
  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto.mutable_tensor_shape();
  PredictResponse expected_response;
  *expected_response.mutable_model_spec() = *model_spec;
  (*expected_response.mutable_outputs())[kClassifyOutputScores] =
      output_tensor_proto;
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}
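
// Exercises the stateful counter model through its custom signatures:
// get_counter, incr_counter, reset_counter, and incr_counter_by.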
TEST_F(PredictImplTest, PredictionWithCustomizedSignatures) {
  PredictRequest request;
  PredictResponse response;

  // Call get_counter. The counter starts at 0.
  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  model_spec->set_signature_name("get_counter");

  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(), request, &response));

  PredictResponse expected_get_counter;
  *expected_get_counter.mutable_model_spec() = *model_spec;
  TensorProto output_get_counter;
  output_get_counter.add_float_val(0);
  output_get_counter.set_dtype(tensorflow::DT_FLOAT);
  output_get_counter.mutable_tensor_shape();
  (*expected_get_counter.mutable_outputs())["output"] = output_get_counter;
  EXPECT_THAT(response, test_util::EqualsProto(expected_get_counter));

  // Call incr_counter. The counter becomes 1.
  model_spec->set_signature_name("incr_counter");
  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(), request, &response));

  PredictResponse expected_incr_counter;
  *expected_incr_counter.mutable_model_spec() = *model_spec;
  TensorProto output_incr_counter;
  output_incr_counter.add_float_val(1);
  output_incr_counter.set_dtype(tensorflow::DT_FLOAT);
  output_incr_counter.mutable_tensor_shape();
  (*expected_incr_counter.mutable_outputs())["output"] = output_incr_counter;
  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter));

  // Call reset_counter. The counter goes back to 0.
  model_spec->set_signature_name("reset_counter");
  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(), request, &response));

  PredictResponse expected_reset_counter;
  *expected_reset_counter.mutable_model_spec() = *model_spec;
  TensorProto output_reset_counter;
  output_reset_counter.add_float_val(0);
  output_reset_counter.set_dtype(tensorflow::DT_FLOAT);
  output_reset_counter.mutable_tensor_shape();
  (*expected_reset_counter.mutable_outputs())["output"] = output_reset_counter;
  EXPECT_THAT(response, test_util::EqualsProto(expected_reset_counter));

  // Call incr_counter again, this time with an output filter. The counter
  // becomes 1.
  model_spec->set_signature_name("incr_counter");
  request.add_output_filter("output");
  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(), request, &response));
  request.clear_output_filter();

  PredictResponse expected_incr_counter2;
  *expected_incr_counter2.mutable_model_spec() = *model_spec;
  TensorProto output_incr_counter2;
  output_incr_counter2.add_float_val(1);
  output_incr_counter2.set_dtype(tensorflow::DT_FLOAT);
  output_incr_counter2.mutable_tensor_shape();
  (*expected_incr_counter2.mutable_outputs())["output"] = output_incr_counter2;
  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter2));

  // Call incr_counter_by with delta 3. The counter becomes 4.
  model_spec->set_signature_name("incr_counter_by");
  TensorProto tensor_proto;
  tensor_proto.add_float_val(3);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())["delta"] = tensor_proto;

  TF_ASSERT_OK(CallPredict(GetServerCoreWithCounterModel(), request, &response));

  PredictResponse expected_incr_counter_by;
  *expected_incr_counter_by.mutable_model_spec() = *model_spec;
  TensorProto output_incr_counter_by;
  output_incr_counter_by.add_float_val(4);
  output_incr_counter_by.set_dtype(tensorflow::DT_FLOAT);
  output_incr_counter_by.mutable_tensor_shape();
  (*expected_incr_counter_by.mutable_outputs())["output"] =
      output_incr_counter_by;
  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter_by));
}
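
// Caller-provided thread pools passed via ThreadPoolOptions should be used to
// run the session; the counting pools below record whether work was scheduled.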
TEST_F(PredictImplTest, ThreadPoolOptions) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;

  // Single-threaded counting pools are sufficient to observe scheduling.
  test_util::CountingThreadPool inter_op_threadpool(Env::Default(), "InterOp",
                                                    /*num_threads=*/1);
  test_util::CountingThreadPool intra_op_threadpool(Env::Default(), "IntraOp",
                                                    /*num_threads=*/1);
  thread::ThreadPoolOptions thread_pool_options;
  thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
  thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
  TF_EXPECT_OK(
      CallPredict(GetServerCore(), request, &response, thread_pool_options));
  TensorProto output_tensor_proto;
  output_tensor_proto.add_float_val(3);
  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto.mutable_tensor_shape();
  PredictResponse expected_response;
  *expected_response.mutable_model_spec() = *model_spec;
  expected_response.mutable_model_spec()->set_signature_name(
      kDefaultServingSignatureDefKey);
  (*expected_response.mutable_outputs())[kOutputTensorKey] =
      output_tensor_proto;
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));

  // At least one op should have been scheduled on the inter-op pool.
  ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
}
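
// With the method-name check enabled, RunPredict should reject a signature
// whose method_name is unsupported; with the check disabled, any method_name
// is accepted.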
TEST_F(PredictImplTest, MethodNameCheck) {
  ServableHandle<SavedModelBundle> bundle;
  TF_ASSERT_OK(GetSavedModelServableHandle(GetServerCore(), &bundle));
  MetaGraphDef meta_graph_def = bundle->meta_graph_def;
  auto* signature_defs = meta_graph_def.mutable_signature_def();

  PredictRequest request;
  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;

  FakeSession fake_session;
  PredictResponse response;

  bool old_val = GetSignatureMethodNameCheckFeature();

  SetSignatureMethodNameCheckFeature(true);
  // With the check enabled, a supported method name succeeds.
  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      kPredictMethodName);
  TF_EXPECT_OK(RunPredict(GetRunOptions(), meta_graph_def, kTestModelVersion,
                          &fake_session, request, &response,
                          thread::ThreadPoolOptions()));
  // An unsupported method name is rejected.
  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      "not/supported/method");
  EXPECT_FALSE(RunPredict(GetRunOptions(), meta_graph_def, kTestModelVersion,
                          &fake_session, request, &response,
                          thread::ThreadPoolOptions())
                   .ok());

  SetSignatureMethodNameCheckFeature(false);
  // With the check disabled, both method names are accepted.
  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      kPredictMethodName);
  TF_EXPECT_OK(RunPredict(GetRunOptions(), meta_graph_def, kTestModelVersion,
                          &fake_session, request, &response,
                          thread::ThreadPoolOptions()));

  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      "not/supported/method");
  TF_EXPECT_OK(RunPredict(GetRunOptions(), meta_graph_def, kTestModelVersion,
                          &fake_session, request, &response,
                          thread::ThreadPoolOptions()));

  SetSignatureMethodNameCheckFeature(old_val);
}

}  // namespace
}  // namespace serving
}  // namespace tensorflow