16 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
24 #include "google/protobuf/map.h"
25 #include "tensorflow/cc/saved_model/signature_constants.h"
26 #include "tensorflow/core/example/example.pb.h"
27 #include "tensorflow/core/example/feature.pb.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/public/session.h"
35 #include "tensorflow/core/tfrt/utils/tensor_util.h"
36 #include "tsl/platform/env.h"
37 #include "tensorflow_serving/apis/classification.pb.h"
38 #include "tensorflow_serving/apis/input.pb.h"
39 #include "tensorflow_serving/apis/model.pb.h"
40 #include "tensorflow_serving/config/model_server_config.pb.h"
41 #include "tensorflow_serving/core/availability_preserving_policy.h"
42 #include "tensorflow_serving/core/servable_handle.h"
43 #include "tensorflow_serving/model_servers/model_platform_types.h"
44 #include "tensorflow_serving/model_servers/platform_config_util.h"
45 #include "tensorflow_serving/model_servers/server_core.h"
46 #include "tensorflow_serving/servables/tensorflow/servable.h"
47 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
48 #include "tensorflow_serving/servables/tensorflow/test_util/mock_tfrt_saved_model.h"
49 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
50 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
51 #include "tensorflow_serving/test_util/test_util.h"
53 namespace tensorflow {
58 using ::testing::DoAll;
59 using ::testing::HasSubstr;
60 using ::testing::Return;
61 using ::testing::WithArgs;
// Model name used when registering the test model with the server and when
// looking up its servable handle.
63 constexpr
char kTestModelName[] =
"test_model";
// Version number passed to RunClassify() by the mock-based tests below.
64 constexpr
int kTestModelVersion = 123;
// Test fixture for the TFRT classifier tests. It holds a suite-wide
// ServerCore (static member), a per-test ClassificationRequest, and helpers
// for fetching a servable handle and issuing a Classify call.
// NOTE(review): several original lines are not visible in this excerpt
// (closing braces, access specifiers, and the place where server_core_ is
// created), so the comments below describe only the visible statements.
66 class TfrtClassifierTest :
public ::testing::Test {
// Suite-level setup: installs a global TFRT runtime obtained from
// tfrt_stub::Runtime::Create(4).
// NOTE(review): the argument 4 is presumably a worker-thread count — confirm.
68 static void SetUpTestSuite() {
69 tfrt_stub::SetGlobalRuntime(
70 tfrt_stub::Runtime::Create(4));
// Per-test setup: builds the model server configuration, the ServerCore
// options, and the default request stored in request_.
73 void SetUp()
override {
// One model entry: name kTestModelName, base path pointing at the
// saved_model_half_plus_two_cpu test data, TensorFlow model platform.
74 ModelServerConfig config;
75 auto model_config = config.mutable_model_config_list()->add_config();
76 model_config->set_name(kTestModelName);
77 model_config->set_base_path(
78 test_util::TestSrcDirPath(
"servables/tensorflow/"
79 "testdata/saved_model_half_plus_two_cpu"));
80 model_config->set_model_platform(kTensorFlowModelPlatform);
// ServerCore options: the platform config map entry for the TensorFlow
// platform gets a default-constructed TfrtSavedModelSourceAdapterConfig
// packed into an Any as its source adapter config.
84 ServerCore::Options options;
85 options.model_server_config = config;
86 PlatformConfigMap platform_config_map;
87 ::google::protobuf::Any source_adapter_config;
88 TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
89 source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
90 (*(*platform_config_map
91 .mutable_platform_configs())[kTensorFlowModelPlatform]
92 .mutable_source_adapter_config()) = source_adapter_config;
93 options.platform_config_map = platform_config_map;
// Version policy used by the server: AvailabilityPreservingPolicy.
94 options.aspired_version_policy =
95 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
// Use the same thread count for the initial load as for regular loads.
98 options.num_initial_load_threads = options.num_load_threads;
// Default request reused by the mock-based tests; it is parsed from a text
// proto naming test_model and the classify_x_to_y signature.
// NOTE(review): the remainder of the text proto is not visible here.
101 request_ = test_util::CreateProto<ClassificationRequest>(
103 " name: \"test_model\""
104 " signature_name: \"classify_x_to_y\""
// Suite-level teardown: releases the shared ServerCore.
124 static void TearDownTestSuite() { server_core_ =
nullptr; }
// Fetches a servable handle from `server_core` for the model named
// kTestModelName. Only the name is set on the ModelSpec; no version is
// specified. Returns the status of ServerCore::GetServableHandle().
127 Status GetSavedModelServableHandle(ServerCore* server_core,
128 ServableHandle<Servable>* servable) {
129 ModelSpec model_spec;
130 model_spec.set_name(kTestModelName);
131 return server_core->GetServableHandle(model_spec, servable);
// Obtains the servable handle and runs Classify() on it with default
// ({}) run options. If the handle lookup fails, that error is returned
// and Classify() is not called.
134 Status CallClassify(ServerCore* server_core,
135 const ClassificationRequest& request,
136 ClassificationResponse* response) {
137 ServableHandle<Servable> servable;
138 TF_RETURN_IF_ERROR(GetSavedModelServableHandle(server_core, &servable));
139 return servable->Classify({}, request, response);
// NOTE(review): classifier_ is not referenced in the visible lines.
143 std::unique_ptr<ClassifierInterface> classifier_;
// Shared across all tests in the suite; reset in TearDownTestSuite().
144 static std::unique_ptr<ServerCore> server_core_;
// Default request assigned in SetUp().
145 ClassificationRequest request_;
// Out-of-class definition of the static server_core_ member declared in
// TfrtClassifierTest.
148 std::unique_ptr<ServerCore> TfrtClassifierTest::server_core_;
// Happy path through the real server: sends a classification request for
// test_model / classify_x_to_y and checks the full response proto.
// NOTE(review): most of the request text proto is not visible in this excerpt.
150 TEST_F(TfrtClassifierTest, Basic) {
151 auto request = test_util::CreateProto<ClassificationRequest>(
153 " name: \"test_model\""
154 " signature_name: \"classify_x_to_y\""
172 " value: [ \"pt_BR\" ]"
200 ClassificationResponse response;
// The call must succeed, and the response must contain two classifications
// with scores 42 and 12, plus the name, signature_name and version (123)
// fields shown below (their enclosing message is not visible here).
202 TF_EXPECT_OK(CallClassify(server_core_.get(), request, &response));
203 EXPECT_THAT(response,
204 test_util::EqualsProto(
205 "result { classifications { classes { "
206 "score: 42 } } classifications { classes {score: 12 } } }"
208 " name: \"test_model\""
209 " signature_name: \"classify_x_to_y\""
210 " version { value: 123 }"
// Same happy path as Basic, but the request input is given as an
// example_list_with_context. The expected response is the same two
// classifications (scores 42 and 12) with the same name, signature_name and
// version (123) fields.
// NOTE(review): most of the request text proto is not visible in this excerpt.
214 TEST_F(TfrtClassifierTest, BasicWithContext) {
215 auto request = test_util::CreateProto<ClassificationRequest>(
217 " name: \"test_model\""
218 " signature_name: \"classify_x_to_y\""
221 " example_list_with_context {"
236 " value: [ \"pt_BR\" ]"
276 ClassificationResponse response;
// The call must succeed and the response must match the expected proto.
278 TF_EXPECT_OK(CallClassify(server_core_.get(), request, &response));
279 EXPECT_THAT(response, test_util::EqualsProto(
280 "result { classifications { classes { score: 42 }} "
281 "classifications { classes { score: 12 }}}"
283 " name: \"test_model\""
284 " signature_name: \"classify_x_to_y\""
285 " version { value: 123 }"
// Classify must reject this request with kInvalidArgument and a message
// containing "Input is empty".
// NOTE(review): per the test name the request carries an empty example list;
// that part of the text proto is not visible in this excerpt.
289 TEST_F(TfrtClassifierTest, EmptyExampleList) {
290 auto request = test_util::CreateProto<ClassificationRequest>(
292 " name: \"test_model\""
293 " signature_name: \"classify_x_to_y\""
299 ClassificationResponse response;
// Capture the status so both the code and the message can be checked.
301 Status status = CallClassify(server_core_.get(), request, &response);
302 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
303 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
// Variant of EmptyExampleList that uses an example_list_with_context input.
// Classify must fail with kInvalidArgument and a message containing
// "Input is empty".
// NOTE(review): the contents of example_list_with_context are not visible in
// this excerpt.
306 TEST_F(TfrtClassifierTest, EmptyExampleListWithContext) {
307 auto request = test_util::CreateProto<ClassificationRequest>(
309 " name: \"test_model\""
310 " signature_name: \"classify_x_to_y\""
313 " example_list_with_context {"
328 ClassificationResponse response;
// Capture the status so both the code and the message can be checked.
330 Status status = CallClassify(server_core_.get(), request, &response);
331 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
332 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
// Classify must fail with kInvalidArgument and a message containing
// "Input is empty".
// NOTE(review): per the test name the request has an empty input; the input
// portion of the text proto is not visible in this excerpt.
335 TEST_F(TfrtClassifierTest, EmptyInput) {
336 auto request = test_util::CreateProto<ClassificationRequest>(
338 " name: \"test_model\""
339 " signature_name: \"classify_x_to_y\""
343 ClassificationResponse response;
// Capture the status so both the code and the message can be checked.
345 Status status = CallClassify(server_core_.get(), request, &response);
346 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
347 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
// When the saved model has no metadata for the requested function
// (GetFunctionMetadata returns std::nullopt), RunClassify must fail with
// kFailedPrecondition and a message containing "not found".
350 TEST_F(TfrtClassifierTest, InvalidFunctionName) {
351 ClassificationResponse response;
352 std::unique_ptr<test_util::MockSavedModel> saved_model(
353 (
new test_util::MockSavedModel()));
// Every metadata lookup on the mock reports "no such function".
354 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
356 .WillRepeatedly(Return(std::nullopt));
// Calls RunClassify directly on the mock with the fixture's default request.
357 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
358 saved_model.get(), request_, &response);
359 EXPECT_EQ(status.code(), absl::StatusCode::kFailedPrecondition);
360 EXPECT_THAT(status.message(), HasSubstr(
"not found"));
// A function signature declaring two inputs must be rejected: RunClassify is
// expected to fail with kInvalidArgument and a message containing
// "Expected one input Tensor.".
363 TEST_F(TfrtClassifierTest, InvalidFunctionUnmatchedInputSize) {
364 ClassificationResponse response;
365 std::unique_ptr<test_util::MockSavedModel> saved_model(
366 (
new test_util::MockSavedModel()));
// Signature with an extra, unexpected input name next to kClassifyInputs.
367 tfrt::internal::Signature signature;
368 signature.input_names = {kClassifyInputs,
"wrong input"};
369 signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores};
// The metadata is constructed from a pointer to the local signature.
370 tfrt::FunctionMetadata function_metadata(&signature);
371 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
373 .WillRepeatedly(Return(function_metadata));
374 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
375 saved_model.get(), request_, &response);
376 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
377 EXPECT_THAT(status.message(), HasSubstr(
"Expected one input Tensor."));
// A function signature with too many outputs must be rejected: RunClassify is
// expected to fail with kInvalidArgument and a message containing
// "Expected one or two output Tensors".
380 TEST_F(TfrtClassifierTest, InvalidFunctionUnmatchedOutputSize) {
381 ClassificationResponse response;
382 std::unique_ptr<test_util::MockSavedModel> saved_model(
383 (
new test_util::MockSavedModel()));
384 tfrt::internal::Signature signature;
385 signature.input_names = {kClassifyInputs};
// NOTE(review): the rest of this initializer list (the extra output name) is
// not visible in this excerpt.
386 signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores,
388 tfrt::FunctionMetadata function_metadata(&signature);
389 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
391 .WillRepeatedly(Return(function_metadata));
392 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
393 saved_model.get(), request_, &response);
394 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
395 EXPECT_THAT(status.message(),
396 HasSubstr(
"Expected one or two output Tensors"));
// A function signature whose single input is not kClassifyInputs must be
// rejected: RunClassify is expected to fail with kFailedPrecondition, and the
// test looks for "No classification inputs found in function's metadata".
399 TEST_F(TfrtClassifierTest, InvalidFunctionInvalidInputName) {
400 ClassificationResponse response;
401 std::unique_ptr<test_util::MockSavedModel> saved_model(
402 (
new test_util::MockSavedModel()));
// One input, but under the wrong name.
403 tfrt::internal::Signature signature;
404 signature.input_names = {
"wrong input"};
405 signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores};
406 tfrt::FunctionMetadata function_metadata(&signature);
407 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
409 .WillRepeatedly(Return(function_metadata));
410 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
411 saved_model.get(), request_, &response);
412 EXPECT_EQ(status.code(), absl::StatusCode::kFailedPrecondition);
// NOTE(review): the opening of this message assertion is not visible in this
// excerpt; presumably it is EXPECT_THAT(status.message(), ...) as in the
// sibling tests — confirm.
415 HasSubstr(
"No classification inputs found in function's metadata"));
// A function signature with an unrecognized output name must be rejected:
// RunClassify is expected to fail with kFailedPrecondition and a message
// containing "Expected classification function outputs to contain".
418 TEST_F(TfrtClassifierTest, InvalidFunctionInvalidOutputName) {
419 ClassificationResponse response;
420 std::unique_ptr<test_util::MockSavedModel> saved_model(
421 (
new test_util::MockSavedModel()));
422 tfrt::internal::Signature signature;
423 signature.input_names = {kClassifyInputs};
// Two outputs, the first of which carries a wrong name.
424 signature.output_names = {
"wrong output", kClassifyOutputScores};
425 tfrt::FunctionMetadata function_metadata(&signature);
426 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
428 .WillRepeatedly(Return(function_metadata));
429 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
430 saved_model.get(), request_, &response);
431 EXPECT_EQ(status.code(), absl::StatusCode::kFailedPrecondition);
432 EXPECT_THAT(status.message(),
433 HasSubstr(
"Expected classification function outputs to contain"));
// With a valid signature, an error returned by SavedModel::Run must surface
// from RunClassify: the test expects kInvalidArgument and a message
// containing "test error".
436 TEST_F(TfrtClassifierTest, RunsFails) {
437 ClassificationResponse response;
438 std::unique_ptr<test_util::MockSavedModel> saved_model(
439 (
new test_util::MockSavedModel()));
// Well-formed signature: one input, classes and scores outputs.
440 tfrt::internal::Signature signature;
441 signature.input_names = {kClassifyInputs};
442 signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores};
443 tfrt::FunctionMetadata function_metadata(&signature);
444 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
446 .WillRepeatedly(Return(function_metadata));
// An<absl::Span<const Tensor>>() constrains the third argument's type, which
// selects the Run overload taking a span of input tensors. The mocked Run
// always fails with InvalidArgument("test error").
447 EXPECT_CALL(*saved_model,
448 Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
450 .WillRepeatedly(Return(errors::InvalidArgument(
"test error")));
451 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
452 saved_model.get(), request_, &response);
453 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
454 EXPECT_THAT(status.message(), HasSubstr(
"test error"));
// The signature declares two outputs (classes and scores), while the visible
// mocked Run action appends `output` to the result vector once and returns OK.
// RunClassify is expected to fail with kInvalidArgument and a message
// containing "Unexpected output tensors size".
457 TEST_F(TfrtClassifierTest, UnexpectedOutputTensorNumber) {
458 ClassificationResponse response;
459 std::unique_ptr<test_util::MockSavedModel> saved_model(
460 (
new test_util::MockSavedModel()));
461 tfrt::internal::Signature signature;
462 signature.input_names = {kClassifyInputs};
463 signature.output_names = {kClassifyOutputClasses, kClassifyOutputScores};
464 tfrt::FunctionMetadata function_metadata(&signature);
465 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
467 .WillRepeatedly(Return(function_metadata));
// NOTE(review): the declaration of `output` and the start of the action
// clause are not visible in this excerpt.
469 EXPECT_CALL(*saved_model,
470 Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
// WithArgs<3> hands the fourth Run argument (the output vector) to the lambda.
473 DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
474 output_tensors->push_back(output);
476 Return(absl::OkStatus())));
477 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
478 saved_model.get(), request_, &response);
479 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
480 EXPECT_THAT(status.message(), HasSubstr(
"Unexpected output tensors size"));
// The mocked Run returns a DT_FLOAT tensor of shape {1, 1, 1} for a
// scores-only signature. RunClassify is expected to fail with
// kInvalidArgument and a message containing "Expected Tensor shape".
483 TEST_F(TfrtClassifierTest, UnexpectedOutputTensorShape) {
484 ClassificationResponse response;
485 std::unique_ptr<test_util::MockSavedModel> saved_model(
486 (
new test_util::MockSavedModel()));
// Signature with a single output: scores.
487 tfrt::internal::Signature signature;
488 signature.input_names = {kClassifyInputs};
489 signature.output_names = {kClassifyOutputScores};
490 tfrt::FunctionMetadata function_metadata(&signature);
491 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
493 .WillRepeatedly(Return(function_metadata));
// Three-dimensional output tensor used to trigger the shape error.
494 Tensor output(DT_FLOAT, TensorShape({1, 1, 1}));
// NOTE(review): the start of the action clause is not visible in this excerpt.
495 EXPECT_CALL(*saved_model,
496 Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
// WithArgs<3> hands the fourth Run argument (the output vector) to the lambda.
499 DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
500 output_tensors->push_back(output);
502 Return(absl::OkStatus())));
503 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
504 saved_model.get(), request_, &response);
505 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
506 EXPECT_THAT(status.message(), HasSubstr(
"Expected Tensor shape"));
// The mocked Run returns a DT_STRING tensor of shape {1, 1} for a scores-only
// signature. RunClassify is expected to fail with kInvalidArgument and a
// message containing "Expected scores Tensor of DT_FLOAT".
509 TEST_F(TfrtClassifierTest, UnexpectedOutputTensorType) {
510 ClassificationResponse response;
511 std::unique_ptr<test_util::MockSavedModel> saved_model(
512 (
new test_util::MockSavedModel()));
// Signature with a single output: scores.
513 tfrt::internal::Signature signature;
514 signature.input_names = {kClassifyInputs};
515 signature.output_names = {kClassifyOutputScores};
516 tfrt::FunctionMetadata function_metadata(&signature);
517 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
519 .WillRepeatedly(Return(function_metadata));
// String-typed output tensor used to trigger the dtype error.
520 Tensor output(DT_STRING, TensorShape({1, 1}));
// NOTE(review): the start of the action clause is not visible in this excerpt.
521 EXPECT_CALL(*saved_model,
522 Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
// WithArgs<3> hands the fourth Run argument (the output vector) to the lambda.
525 DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
526 output_tensors->push_back(output);
528 Return(absl::OkStatus())));
529 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
530 saved_model.get(), request_, &response);
531 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
532 EXPECT_THAT(status.message(),
533 HasSubstr(
"Expected scores Tensor of DT_FLOAT"));
// The mocked Run returns a DT_FLOAT tensor of shape {10, 1} for a scores-only
// signature. RunClassify is expected to fail with kInvalidArgument and a
// message containing "Expected scores output batch size of".
// NOTE(review): presumably 10 does not match the number of examples in
// request_; the request contents are not visible in this excerpt — confirm.
536 TEST_F(TfrtClassifierTest, UnexpectedOutputTensorSize) {
537 ClassificationResponse response;
538 std::unique_ptr<test_util::MockSavedModel> saved_model(
539 (
new test_util::MockSavedModel()));
// Signature with a single output: scores.
540 tfrt::internal::Signature signature;
541 signature.input_names = {kClassifyInputs};
542 signature.output_names = {kClassifyOutputScores};
543 tfrt::FunctionMetadata function_metadata(&signature);
544 EXPECT_CALL(*saved_model, GetFunctionMetadata(_))
546 .WillRepeatedly(Return(function_metadata));
// Output tensor whose first dimension is 10.
547 Tensor output(DT_FLOAT, TensorShape({10, 1}));
// NOTE(review): the start of the action clause is not visible in this excerpt.
548 EXPECT_CALL(*saved_model,
549 Run(_, _, ::testing::An<absl::Span<const Tensor>>(), _))
// WithArgs<3> hands the fourth Run argument (the output vector) to the lambda.
552 DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
553 output_tensors->push_back(output);
555 Return(absl::OkStatus())));
556 auto status = RunClassify(tfrt::SavedModel::RunOptions(), kTestModelVersion,
557 saved_model.get(), request_, &response);
558 EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
559 EXPECT_THAT(status.message(),
560 HasSubstr(
"Expected scores output batch size of"));
static Status Create(Options options, std::unique_ptr< ServerCore > *core)