#include "tensorflow_serving/servables/tensorflow/tfrt_predict_util.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/core/servable_handle.h"
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/model_servers/platform_config_util.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/servables/tensorflow/test_util/mock_tfrt_saved_model.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
#include "tensorflow_serving/test_util/test_util.h"
#include "tensorflow_serving/util/oss_or_google.h"
#include "xla/tsl/lib/core/status_test_util.h"
45 namespace tensorflow {
49 using ::testing::DoAll;
50 using ::testing::HasSubstr;
51 using ::testing::Return;
52 using ::testing::ReturnRef;
53 using ::testing::WithArgs;
55 constexpr
char kTestModelName[] =
"test_model";
56 constexpr
int kTestModelVersion = 123;
58 const char kInputTensorKey[] =
"x";
59 const char kOutputTensorKey[] =
"y";
61 class PredictImplTest :
public ::testing::Test {
63 static void SetUpTestSuite() {
64 tfrt_stub::SetGlobalRuntime(
65 tfrt_stub::Runtime::Create(4));
67 ModelServerConfig config;
68 auto model_config = config.mutable_model_config_list()->add_config();
69 model_config->set_name(kTestModelName);
70 model_config->set_base_path(
71 test_util::TestSrcDirPath(
"servables/tensorflow/testdata/"
72 "saved_model_half_plus_two_tf2_cpu"));
73 model_config->set_model_platform(kTensorFlowModelPlatform);
77 ServerCore::Options options;
78 options.model_server_config = config;
79 PlatformConfigMap platform_config_map;
80 ::google::protobuf::Any source_adapter_config;
81 TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
82 source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
83 (*(*platform_config_map
84 .mutable_platform_configs())[kTensorFlowModelPlatform]
85 .mutable_source_adapter_config()) = source_adapter_config;
86 options.platform_config_map = platform_config_map;
87 options.aspired_version_policy =
88 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
91 options.num_initial_load_threads = options.num_load_threads;
96 static void TearDownTestSuite() { saved_model_server_core_.reset(); }
99 Status GetSavedModelServableHandle(ServerCore* server_core,
100 ServableHandle<Servable>* servable) {
101 ModelSpec model_spec;
102 model_spec.set_name(kTestModelName);
103 return server_core->GetServableHandle(model_spec, servable);
106 ServerCore* GetServerCore() {
return saved_model_server_core_.get(); }
108 Status CallPredict(ServerCore* server_core,
const PredictRequest& request,
109 PredictResponse* response,
110 absl::Duration timeout = absl::ZeroDuration()) {
111 ServableHandle<Servable> servable;
112 TF_RETURN_IF_ERROR(GetSavedModelServableHandle(server_core, &servable));
115 Servable::RunOptions run_options;
116 if (timeout != absl::ZeroDuration())
117 run_options.deadline = absl::Now() + timeout;
118 return servable->Predict(run_options, request, response);
122 static std::unique_ptr<ServerCore> saved_model_server_core_;
125 std::unique_ptr<ServerCore> PredictImplTest::saved_model_server_core_;
127 TEST_F(PredictImplTest, PredictionSuccess) {
128 PredictRequest request;
129 PredictResponse response;
131 ModelSpec* model_spec = request.mutable_model_spec();
132 model_spec->set_name(kTestModelName);
133 model_spec->mutable_version()->set_value(kTestModelVersion);
135 TensorProto tensor_proto;
136 tensor_proto.add_float_val(2.0);
137 tensor_proto.set_dtype(tensorflow::DT_FLOAT);
138 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
140 TF_EXPECT_OK(CallPredict(GetServerCore(), request, &response));
141 TensorProto output_tensor_proto;
142 output_tensor_proto.add_float_val(3);
143 output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
144 output_tensor_proto.mutable_tensor_shape();
145 PredictResponse expected_response;
146 *expected_response.mutable_model_spec() = *model_spec;
147 expected_response.mutable_model_spec()->set_signature_name(
148 kDefaultServingSignatureDefKey);
149 (*expected_response.mutable_outputs())[kOutputTensorKey] =
151 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
154 TEST_F(PredictImplTest, PredictionSuccessWithDefaultInputs) {
155 PredictRequest request;
156 PredictResponse response;
158 ModelSpec* model_spec = request.mutable_model_spec();
159 model_spec->set_name(kTestModelName);
160 model_spec->mutable_version()->set_value(kTestModelVersion);
163 TF_EXPECT_OK(CallPredict(GetServerCore(), request, &response));
164 TensorProto output_tensor_proto;
165 output_tensor_proto.add_float_val(2);
166 output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
167 output_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
168 PredictResponse expected_response;
169 *expected_response.mutable_model_spec() = *model_spec;
170 expected_response.mutable_model_spec()->set_signature_name(
171 kDefaultServingSignatureDefKey);
172 (*expected_response.mutable_outputs())[kOutputTensorKey] =
174 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
177 TEST_F(PredictImplTest, PredictionInvalidTensor) {
178 PredictRequest request;
179 PredictResponse response;
181 ModelSpec* model_spec = request.mutable_model_spec();
182 model_spec->set_name(kTestModelName);
183 model_spec->mutable_version()->set_value(kTestModelVersion);
185 TensorProto tensor_proto;
186 tensor_proto.add_bool_val(
true);
187 tensor_proto.set_dtype(tensorflow::DT_BOOL);
188 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
190 auto status = CallPredict(GetServerCore(), request, &response);
191 EXPECT_EQ(status.code(), tensorflow::error::Code::INVALID_ARGUMENT);
192 EXPECT_THAT(status.message(), HasSubstr(
"Expected input x to be float"));
195 TEST_F(PredictImplTest, PredictionMissingFunction) {
196 PredictRequest request;
197 PredictResponse response;
199 TensorProto tensor_proto;
200 tensor_proto.add_float_val(2.0);
201 tensor_proto.set_dtype(tensorflow::DT_FLOAT);
202 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
204 std::unique_ptr<test_util::MockSavedModel> saved_model(
205 (
new test_util::MockSavedModel()));
206 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
208 .WillRepeatedly(Return(std::nullopt));
210 RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
211 saved_model.get(), request, &response);
212 EXPECT_EQ(status.code(), tensorflow::error::Code::FAILED_PRECONDITION);
213 EXPECT_THAT(status.message(), HasSubstr(
"not found"));
216 TEST_F(PredictImplTest, PredictionMissingInput) {
217 PredictRequest request;
218 request.mutable_model_spec()->set_name(kTestModelName);
219 PredictResponse response;
221 TensorProto tensor_proto;
222 tensor_proto.add_float_val(2.0);
223 tensor_proto.set_dtype(tensorflow::DT_FLOAT);
224 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
226 std::unique_ptr<test_util::MockSavedModel> saved_model(
227 (
new test_util::MockSavedModel()));
228 tfrt::internal::Signature signature;
229 signature.input_names = {
"unknown"};
230 tfrt::FunctionMetadata function_metadata(&signature);
231 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
233 .WillRepeatedly(Return(function_metadata));
235 RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
236 saved_model.get(), request, &response);
237 EXPECT_EQ(status.code(), tensorflow::error::Code::INVALID_ARGUMENT);
241 "Request inputs do not match required inputs for model "
242 "`test_model`. Send extra: {x}. Missing but required: {unknown}."));
245 TEST_F(PredictImplTest, PredictionRunError) {
246 PredictRequest request;
247 PredictResponse response;
249 TensorProto tensor_proto;
250 tensor_proto.add_float_val(2.0);
251 tensor_proto.set_dtype(tensorflow::DT_FLOAT);
252 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
254 std::unique_ptr<test_util::MockSavedModel> saved_model(
255 (
new test_util::MockSavedModel()));
256 tfrt::internal::Signature signature;
257 signature.input_names = {
"x"};
258 tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
259 signature.input_specs = {spec};
260 tfrt::FunctionMetadata function_metadata(&signature);
261 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
263 .WillRepeatedly(Return(function_metadata));
264 EXPECT_CALL(*saved_model,
265 Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
267 .WillRepeatedly(Return(errors::InvalidArgument(
"test error")));
269 RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
270 saved_model.get(), request, &response);
271 EXPECT_EQ(status.code(), tensorflow::error::Code::INVALID_ARGUMENT);
272 EXPECT_THAT(status.message(), HasSubstr(
"test error"));
275 TEST_F(PredictImplTest, PredictionUnmatchedOutputNumber) {
276 PredictRequest request;
277 PredictResponse response;
279 TensorProto tensor_proto;
280 tensor_proto.add_float_val(2.0);
281 tensor_proto.set_dtype(tensorflow::DT_FLOAT);
282 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
284 std::unique_ptr<test_util::MockSavedModel> saved_model(
285 (
new test_util::MockSavedModel()));
286 tfrt::internal::Signature signature;
287 signature.input_names = {
"x"};
288 tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
289 signature.input_specs = {spec};
290 tfrt::FunctionMetadata function_metadata(&signature);
291 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
293 .WillRepeatedly(Return(function_metadata));
296 EXPECT_CALL(*saved_model,
297 Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
300 DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
301 output_tensors->push_back(output);
302 output_tensors->push_back(output);
304 Return(absl::OkStatus())));
306 RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
307 saved_model.get(), request, &response);
308 EXPECT_EQ(status.code(), tensorflow::error::Code::UNKNOWN);
309 EXPECT_THAT(status.message(), HasSubstr(
"Predict internal error."));
312 TEST_F(PredictImplTest, OutputFilters) {
313 PredictRequest request;
314 PredictResponse response;
316 TensorProto tensor_proto;
317 tensor_proto.add_float_val(2.0);
318 tensor_proto.set_dtype(tensorflow::DT_FLOAT);
319 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
320 request.add_output_filter(
"output1");
322 std::unique_ptr<test_util::MockSavedModel> saved_model(
323 (
new test_util::MockSavedModel()));
324 tfrt::internal::Signature signature;
325 signature.input_names = {
"x"};
326 tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
327 signature.input_specs = {spec};
328 signature.output_names = {
"output1",
"output2"};
329 tfrt::FunctionMetadata function_metadata(&signature);
330 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
332 .WillRepeatedly(Return(function_metadata));
334 tensorflow::SignatureDef signature_def;
335 tensorflow::TensorInfo tensor_info1, tensor_info2, tensor_info3;
336 tensor_info1.set_name(
"x");
337 tensor_info2.set_name(
"output1");
338 tensor_info3.set_name(
"output2");
339 signature_def.mutable_inputs()->insert({kInputTensorKey, tensor_info1});
340 signature_def.mutable_outputs()->insert({
"output1", tensor_info2});
341 signature_def.mutable_outputs()->insert({
"output2", tensor_info3});
342 signature_def.set_method_name(
"tensorflow/serving/predict");
344 tensorflow::MetaGraphDef meta_graph_def;
345 meta_graph_def.mutable_signature_def()->insert(
346 {
"serving_default", signature_def});
347 EXPECT_CALL(*saved_model, GetMetaGraphDef())
349 .WillRepeatedly(ReturnRef(meta_graph_def));
351 TensorProto output_tensor_proto1;
352 output_tensor_proto1.add_float_val(1.0);
353 output_tensor_proto1.set_dtype(tensorflow::DT_FLOAT);
354 output_tensor_proto1.mutable_tensor_shape();
358 RunByTensorNames(_, _, ::testing::SizeIs(1),
359 ::testing::An<absl::Span<const std::string>>(), _))
362 DoAll(WithArgs<4>([&](std::vector<Tensor>* output_tensors) {
363 Tensor output_tensor;
364 CHECK(output_tensor.FromProto(output_tensor_proto1));
365 output_tensors->push_back(output_tensor);
367 Return(absl::OkStatus())));
368 TF_EXPECT_OK(RunPredict(tfrt_stub::SavedModel::RunOptions(),
369 kTestModelVersion, saved_model.get(), request,
371 EXPECT_EQ(response.outputs_size(), 1);
372 EXPECT_TRUE(response.outputs().find(
"output1") != response.outputs().end());
373 EXPECT_THAT(response.outputs().at(
"output1"),
374 test_util::EqualsProto(output_tensor_proto1));
377 TEST_F(PredictImplTest, OutputFiltersFullSet) {
378 PredictRequest request;
379 PredictResponse response;
381 TensorProto tensor_proto;
382 tensor_proto.add_float_val(2.0);
383 tensor_proto.set_dtype(tensorflow::DT_FLOAT);
384 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
385 request.add_output_filter(
"output1");
386 request.add_output_filter(
"output2");
388 std::unique_ptr<test_util::MockSavedModel> saved_model(
389 (
new test_util::MockSavedModel()));
390 tfrt::internal::Signature signature;
391 signature.input_names = {
"x"};
392 tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
393 signature.input_specs = {spec};
394 signature.output_names = {
"output1",
"output2"};
395 tfrt::FunctionMetadata function_metadata(&signature);
396 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
398 .WillRepeatedly(Return(function_metadata));
402 EXPECT_CALL(*saved_model, GetMetaGraphDef()).Times(0);
405 RunByTensorNames(_, _, ::testing::SizeIs(1),
406 ::testing::An<absl::Span<const std::string>>(), _))
408 TensorProto output_tensor_proto1;
409 output_tensor_proto1.add_float_val(1.0);
410 output_tensor_proto1.set_dtype(tensorflow::DT_FLOAT);
411 output_tensor_proto1.mutable_tensor_shape();
412 TensorProto output_tensor_proto2;
413 output_tensor_proto2.add_float_val(2.0);
414 output_tensor_proto2.set_dtype(tensorflow::DT_FLOAT);
415 output_tensor_proto2.mutable_tensor_shape();
416 EXPECT_CALL(*saved_model, Run(_, _, _, _))
419 DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
420 Tensor output_tensor1;
421 CHECK(output_tensor1.FromProto(output_tensor_proto1));
422 output_tensors->push_back(output_tensor1);
423 Tensor output_tensor2;
424 CHECK(output_tensor2.FromProto(output_tensor_proto2));
425 output_tensors->push_back(output_tensor2);
427 Return(absl::OkStatus())));
429 TF_EXPECT_OK(RunPredict(tfrt_stub::SavedModel::RunOptions(),
430 kTestModelVersion, saved_model.get(), request,
432 EXPECT_EQ(response.outputs_size(), 2);
433 EXPECT_TRUE(response.outputs().find(
"output1") != response.outputs().end());
434 EXPECT_THAT(response.outputs().at(
"output1"),
435 test_util::EqualsProto(output_tensor_proto1));
436 EXPECT_TRUE(response.outputs().find(
"output2") != response.outputs().end());
437 EXPECT_THAT(response.outputs().at(
"output2"),
438 test_util::EqualsProto(output_tensor_proto2));
441 TEST_F(PredictImplTest, OutputFiltersWithDefaultInputs) {
442 PredictRequest request;
443 PredictResponse response;
445 request.add_output_filter(
"output1");
447 std::unique_ptr<test_util::MockSavedModel> saved_model(
448 (
new test_util::MockSavedModel()));
449 tfrt::internal::Signature signature;
450 signature.input_names = {
"x"};
451 tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
452 signature.input_specs = {spec};
453 signature.output_names = {
"output1",
"output2"};
455 TensorProto tensor_proto;
456 tensor.AsProtoTensorContent(&tensor_proto);
457 signature.default_inputs[kInputTensorKey] = tensor_proto;
458 tfrt::FunctionMetadata function_metadata(&signature);
459 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
461 .WillRepeatedly(Return(function_metadata));
463 tensorflow::SignatureDef signature_def;
464 tensorflow::TensorInfo tensor_info1, tensor_info2, tensor_info3;
465 tensor_info1.set_name(
"x");
466 tensor_info2.set_name(
"output1");
467 tensor_info3.set_name(
"output2");
468 signature_def.mutable_inputs()->insert({kInputTensorKey, tensor_info1});
469 signature_def.mutable_outputs()->insert({
"output1", tensor_info2});
470 signature_def.mutable_outputs()->insert({
"output2", tensor_info3});
471 signature_def.set_method_name(
"tensorflow/serving/predict");
472 (*signature_def.mutable_defaults())[kInputTensorKey] = tensor_proto;
474 tensorflow::MetaGraphDef meta_graph_def;
475 meta_graph_def.mutable_signature_def()->insert(
476 {
"serving_default", signature_def});
477 EXPECT_CALL(*saved_model, GetMetaGraphDef())
479 .WillRepeatedly(ReturnRef(meta_graph_def));
481 TensorProto output_tensor_proto1;
482 output_tensor_proto1.add_float_val(1.0);
483 output_tensor_proto1.set_dtype(tensorflow::DT_FLOAT);
484 output_tensor_proto1.mutable_tensor_shape();
488 RunByTensorNames(_, ::testing::SizeIs(1), ::testing::SizeIs(1),
489 ::testing::An<absl::Span<const std::string>>(), _))
492 DoAll(WithArgs<4>([&](std::vector<Tensor>* output_tensors) {
493 Tensor output_tensor;
494 CHECK(output_tensor.FromProto(output_tensor_proto1));
495 output_tensors->push_back(output_tensor);
497 Return(absl::OkStatus())));
498 TF_EXPECT_OK(RunPredict(tfrt::SavedModel::RunOptions(), kTestModelVersion,
499 saved_model.get(), request, &response));
500 EXPECT_EQ(response.outputs_size(), 1);
501 EXPECT_TRUE(response.outputs().find(
"output1") != response.outputs().end());
502 EXPECT_THAT(response.outputs().at(
"output1"),
503 test_util::EqualsProto(output_tensor_proto1));
506 TEST_F(PredictImplTest, UnmatchedOutputFilters) {
507 PredictRequest request;
508 PredictResponse response;
510 TensorProto tensor_proto;
511 tensor_proto.add_float_val(2.0);
512 tensor_proto.set_dtype(tensorflow::DT_FLOAT);
513 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
514 request.add_output_filter(
"output1");
515 request.add_output_filter(
"output3");
517 std::unique_ptr<test_util::MockSavedModel> saved_model(
518 (
new test_util::MockSavedModel()));
519 tfrt::internal::Signature signature;
520 signature.input_names = {
"x"};
521 tfrt::TensorSpec spec(tensorflow::DT_FLOAT);
522 signature.input_specs = {spec};
523 signature.output_names = {
"output1",
"output2"};
524 tfrt::FunctionMetadata function_metadata(&signature);
525 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
527 .WillRepeatedly(Return(function_metadata));
528 Tensor output_tensor;
530 tensorflow::SignatureDef signature_def;
531 tensorflow::TensorInfo tensor_info1, tensor_info2, tensor_info3;
532 tensor_info1.set_name(
"x");
533 tensor_info2.set_name(
"output1");
534 tensor_info3.set_name(
"output2");
535 signature_def.mutable_inputs()->insert({kInputTensorKey, tensor_info1});
536 signature_def.mutable_outputs()->insert({
"output1", tensor_info2});
537 signature_def.mutable_outputs()->insert({
"output2", tensor_info3});
538 signature_def.set_method_name(
"tensorflow/serving/predict");
540 tensorflow::MetaGraphDef meta_graph_def;
541 meta_graph_def.mutable_signature_def()->insert(
542 {
"serving_default", signature_def});
543 EXPECT_CALL(*saved_model, GetMetaGraphDef())
545 .WillRepeatedly(ReturnRef(meta_graph_def));
548 RunPredict(tfrt_stub::SavedModel::RunOptions(), kTestModelVersion,
549 saved_model.get(), request, &response);
550 EXPECT_EQ(status.code(), tensorflow::error::Code::INVALID_ARGUMENT);
553 HasSubstr(
"output tensor alias not found in signature: output3 Outputs "
554 "expected to be in the set {output1,output2}."));
557 TEST_F(PredictImplTest, PredictionTimeout) {
558 PredictRequest request;
559 PredictResponse response;
561 ModelSpec* model_spec = request.mutable_model_spec();
562 model_spec->set_name(kTestModelName);
563 model_spec->mutable_version()->set_value(kTestModelVersion);
565 TensorProto tensor_proto;
566 tensor_proto.add_float_val(2.0);
567 tensor_proto.set_dtype(tensorflow::DT_FLOAT);
568 (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
573 CallPredict(GetServerCore(), request, &response, absl::Nanoseconds(1));
575 EXPECT_EQ(status.code(), tensorflow::error::Code::DEADLINE_EXCEEDED);
576 EXPECT_EQ(status.message(),
"Deadline exceeded.");
// Cross-reference left over from the source listing (not part of this file):
//   static Status Create(Options options, std::unique_ptr<ServerCore>* core)