16 #include "tensorflow_serving/servables/tensorflow/classifier.h"
24 #include "google/protobuf/map.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "tensorflow/core/example/example.pb.h"
28 #include "tensorflow/core/example/feature.pb.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/threadpool_options.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/public/session.h"
36 #include "tensorflow_serving/apis/classification.pb.h"
37 #include "tensorflow_serving/apis/input.pb.h"
38 #include "tensorflow_serving/apis/model.pb.h"
39 #include "tensorflow_serving/core/test_util/mock_session.h"
40 #include "tensorflow_serving/servables/tensorflow/util.h"
41 #include "tensorflow_serving/test_util/test_util.h"
43 namespace tensorflow {
47 using test_util::EqualsProto;
48 using test_util::MockSession;
// Tensor names exposed by the fake model's signatures.
constexpr char kInputTensor[] = "input:0";
constexpr char kClassTensor[] = "output:0";
constexpr char kOutputPlusOneClassTensor[] = "outputPlusOne:0";
constexpr char kScoreTensor[] = "score:0";
constexpr char kImproperlySizedScoresTensor[] = "ImproperlySizedScores:0";

// Feature keys written by the fixture's example() helper and read back by
// FakeSession.
constexpr char kClassFeature[] = "class";
constexpr char kScoreFeature[] = "score";

// Names of the additional signatures the fixture registers.
constexpr char kOutputPlusOneSignature[] = "output_plus_one";
constexpr char kInvalidNamedSignature[] = "invalid_regression_signature";
constexpr char kImproperlySizedScoresSignature[] =
    "ImproperlySizedScoresSignature";
// Fake Session used by the classifier tests. Instead of executing a graph,
// Run() parses each serialized Example in the single input tensor and builds
// its outputs from the Example's "class" (bytes) and "score" (float) features.
67 class FakeSession :
public tensorflow::Session {
// expected_timeout: when set, the RunOptions overloads of Run() CHECK that
// run_options.timeout_in_ms() equals it, and the overload without RunOptions
// aborts with LOG(FATAL).
69 explicit FakeSession(absl::optional<int64_t> expected_timeout)
70 : expected_timeout_(expected_timeout) {}
71 ~FakeSession()
override =
default;
// Graph construction and device management are not supported by the fake;
// each of these returns an Unimplemented error.
72 Status Create(
const GraphDef& graph)
override {
73 return errors::Unimplemented(
"not available in fake");
75 Status Extend(
const GraphDef& graph)
override {
76 return errors::Unimplemented(
"not available in fake");
79 Status Close()
override {
80 return errors::Unimplemented(
"not available in fake");
83 Status ListDevices(std::vector<DeviceAttributes>* response)
override {
84 return errors::Unimplemented(
"not available in fake");
// Run without RunOptions: forwards to the RunOptions overload with a
// default-constructed RunOptions. Treated as fatal when a timeout is expected.
87 Status Run(
const std::vector<std::pair<string, Tensor>>& inputs,
88 const std::vector<string>& output_names,
89 const std::vector<string>& target_nodes,
90 std::vector<Tensor>* outputs)
override {
91 if (expected_timeout_) {
92 LOG(FATAL) <<
"Run() without RunOptions not expected to be called";
94 RunMetadata run_metadata;
95 return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
// Forwards to the full overload with default ThreadPoolOptions.
99 Status Run(
const RunOptions& run_options,
100 const std::vector<std::pair<string, Tensor>>& inputs,
101 const std::vector<string>& output_names,
102 const std::vector<string>& target_nodes,
103 std::vector<Tensor>* outputs, RunMetadata* run_metadata)
override {
104 return Run(run_options, inputs, output_names, target_nodes, outputs,
105 run_metadata, thread::ThreadPoolOptions());
// Full Run overload; all other Run overloads end up here.
// - Verifies the timeout when one is expected.
// - Requires exactly one input, named kInputTensor.
// - Appends one output per recognised name in output_names.
108 Status Run(
const RunOptions& run_options,
109 const std::vector<std::pair<string, Tensor>>& inputs,
110 const std::vector<string>& output_names,
111 const std::vector<string>& target_nodes,
112 std::vector<Tensor>* outputs, RunMetadata* run_metadata,
113 const thread::ThreadPoolOptions& thread_pool_options)
override {
114 if (expected_timeout_) {
115 CHECK_EQ(*expected_timeout_, run_options.timeout_in_ms());
// Only a single input tensor named kInputTensor is accepted.
117 if (inputs.size() != 1 || inputs[0].first != kInputTensor) {
118 return errors::Internal(
"Expected one input Tensor.");
121 const Tensor& input = inputs[0].second;
122 std::vector<Example> examples;
123 TF_RETURN_IF_ERROR(GetExamples(input, &examples));
127 GetClassTensor(examples, output_names, &classes, &scores));
// Append outputs in the order the names were requested.
128 for (
const auto& output_name : output_names) {
129 if (output_name == kClassTensor) {
130 outputs->push_back(classes);
131 }
else if (output_name == kScoreTensor ||
132 output_name == kOutputPlusOneClassTensor) {
133 outputs->push_back(scores);
134 }
else if (output_name == kImproperlySizedScoresTensor) {
// Rank-3 float tensor {batch, num_classes, 10}: deliberately malformed
// scores, served through kImproperlySizedScoresSignature.
137 outputs->emplace_back(DT_FLOAT, TensorShape({scores.dim_size(0),
138 scores.dim_size(1), 10}));
142 return absl::OkStatus();
// Parses input.dim_size(0) serialized Examples from the tstring tensor
// `input` into *examples. Returns Internal if any element fails to parse.
146 static Status GetExamples(
const Tensor& input,
147 std::vector<Example>* examples) {
149 const int batch_size = input.dim_size(0);
150 const auto& flat_input = input.flat<tstring>();
151 for (
int i = 0; i < batch_size; ++i) {
153 if (!example.ParseFromArray(flat_input(i).data(), flat_input(i).size())) {
154 return errors::Internal(
"failed to parse example");
156 examples->push_back(example);
158 return absl::OkStatus();
// Looks up feature `name` in the example's feature map.
163 static Feature GetFeature(
const Example& example,
const string& name) {
164 const auto it = example.features().feature().find(name);
165 if (it != example.features().feature().end()) {
// Number of values in whichever of float_list, int64_list or bytes_list is
// set on `feature`.
172 static int FeatureSize(
const Feature& feature) {
173 if (feature.has_float_list()) {
174 return feature.float_list().value_size();
175 }
else if (feature.has_int64_list()) {
176 return feature.int64_list().value_size();
177 }
else if (feature.has_bytes_list()) {
178 return feature.bytes_list().value_size();
// Builds [batch_size, num_classes] tensors from the examples:
// - *classes is DT_STRING, filled from the "class" feature.
// - *scores is DT_FLOAT, filled from the "score" feature.
// num_classes is taken from the first example, and every example must have
// exactly that many classes and scores; otherwise an Internal error is
// returned.
186 static Status GetClassTensor(
const std::vector<Example>& examples,
187 const std::vector<string>& output_names,
188 Tensor* classes, Tensor* scores) {
189 if (examples.empty()) {
190 return errors::Internal(
"empty example list");
// Every score is shifted by +1 when kOutputPlusOneClassTensor is requested.
193 auto iter = std::find(output_names.begin(), output_names.end(),
194 kOutputPlusOneClassTensor);
195 const float offset = iter == output_names.end() ? 0 : 1;
197 const int batch_size = examples.size();
198 const int num_classes = FeatureSize(GetFeature(examples[0], kClassFeature));
199 *classes = Tensor(DT_STRING, TensorShape({batch_size, num_classes}));
200 *scores = Tensor(DT_FLOAT, TensorShape({batch_size, num_classes}));
201 auto classes_matrix = classes->matrix<tstring>();
202 auto scores_matrix = scores->matrix<
float>();
204 for (
int i = 0; i < batch_size; ++i) {
205 const Feature classes_feature = GetFeature(examples[i], kClassFeature);
206 if (FeatureSize(classes_feature) != num_classes) {
207 return errors::Internal(
"incorrect number of classes in feature: ",
208 classes_feature.DebugString());
210 const Feature scores_feature = GetFeature(examples[i], kScoreFeature);
211 if (FeatureSize(scores_feature) != num_classes) {
212 return errors::Internal(
"incorrect number of scores in feature: ",
213 scores_feature.DebugString());
215 for (
int c = 0; c < num_classes; ++c) {
216 classes_matrix(i, c) = classes_feature.bytes_list().value(c);
217 scores_matrix(i, c) = scores_feature.float_list().value(c) + offset;
220 return absl::OkStatus();
// Timeout (ms) that Run() must receive via RunOptions, if any.
224 const absl::optional<int64_t> expected_timeout_;
// Test fixture, parameterized on whether the signature method-name check is
// enabled (GetParam()).
227 class ClassifierTest :
public ::testing::TestWithParam<bool> {
// Builds saved_model_bundle_ around a FakeSession that expects the
// GetRunOptions() timeout. Registers the default classification signature
// plus three named signatures.
229 void SetUp()
override {
230 SetSignatureMethodNameCheckFeature(IsMethodNameCheckEnabled());
231 saved_model_bundle_.reset(
new SavedModelBundle);
232 meta_graph_def_ = &saved_model_bundle_->meta_graph_def;
233 absl::optional<int64_t> expected_timeout = GetRunOptions().timeout_in_ms();
234 fake_session_ =
new FakeSession(expected_timeout);
// The bundle takes ownership; fake_session_ remains a non-owning alias.
235 saved_model_bundle_->session.reset(fake_session_);
// Default signature maps kInputTensor to classes (kClassTensor) and scores
// (kScoreTensor). The method name is set only when the check is enabled.
237 auto* signature_defs = meta_graph_def_->mutable_signature_def();
238 SignatureDef sig_def;
239 TensorInfo input_tensor_info;
240 input_tensor_info.set_name(kInputTensor);
241 (*sig_def.mutable_inputs())[kClassifyInputs] = input_tensor_info;
242 TensorInfo class_tensor_info;
243 class_tensor_info.set_name(kClassTensor);
244 (*sig_def.mutable_outputs())[kClassifyOutputClasses] = class_tensor_info;
245 TensorInfo scores_tensor_info;
246 scores_tensor_info.set_name(kScoreTensor);
247 (*sig_def.mutable_outputs())[kClassifyOutputScores] = scores_tensor_info;
248 if (IsMethodNameCheckEnabled())
249 sig_def.set_method_name(kClassifyMethodName);
250 (*signature_defs)[kDefaultServingSignatureDefKey] = sig_def;
// Named signatures registered here:
// - kOutputPlusOneSignature: classification; its scores come from
//   kOutputPlusOneClassTensor, which FakeSession offsets by +1.
// - kInvalidNamedSignature: a regression signature.
// - kImproperlySizedScoresSignature: classification with rank-3 scores.
252 AddNamedSignatureToSavedModelBundle(
253 kInputTensor, kOutputPlusOneClassTensor, kOutputPlusOneSignature,
254 true , meta_graph_def_);
255 AddNamedSignatureToSavedModelBundle(
256 kInputTensor, kOutputPlusOneClassTensor, kInvalidNamedSignature,
257 false , meta_graph_def_);
260 AddNamedSignatureToSavedModelBundle(
261 kInputTensor, kImproperlySizedScoresTensor,
262 kImproperlySizedScoresSignature,
true ,
// The test parameter: true when the method-name check is enabled.
267 bool IsMethodNameCheckEnabled() {
return GetParam(); }
// Builds an Example from (class, score) pairs, stored as parallel "class"
// (bytes) and "score" (float) features.
270 Example example(
const std::vector<std::pair<string, float>>& class_scores) {
271 Feature classes_feature;
272 Feature scores_feature;
273 for (
const auto& class_score : class_scores) {
274 classes_feature.mutable_bytes_list()->add_value(class_score.first);
275 scores_feature.mutable_float_list()->add_value(class_score.second);
278 auto* features = example.mutable_features()->mutable_feature();
279 (*features)[kClassFeature] = classes_feature;
280 (*features)[kScoreFeature] = scores_feature;
// Classifier creation: copies the MetaGraphDef, moves the session out of
// saved_model_bundle_ into a new bundle, and builds classifier_ from it.
285 std::unique_ptr<SavedModelBundle> saved_model(
new SavedModelBundle);
286 saved_model->meta_graph_def = saved_model_bundle_->meta_graph_def;
287 saved_model->session = std::move(saved_model_bundle_->session);
288 return CreateClassifierFromSavedModelBundle(
289 GetRunOptions(), std::move(saved_model), &classifier_);
// RunOptions passed to both the classifier and RunClassify(). FakeSession is
// constructed to expect this 42 ms timeout.
292 RunOptions GetRunOptions()
const {
293 RunOptions run_options;
294 run_options.set_timeout_in_ms(42);
// Adds signature `signature_name`, which reads from input_tensor_name.
// - If is_classification: a classification signature with scores from
//   output_scores_tensor_name and classes from kClassTensor.
// - Otherwise: a regression signature whose kRegressOutputs is
//   output_scores_tensor_name.
// The method name is recorded only when the method-name check is enabled.
301 void AddNamedSignatureToSavedModelBundle(
302 const string& input_tensor_name,
const string& output_scores_tensor_name,
303 const string& signature_name,
const bool is_classification,
304 tensorflow::MetaGraphDef* meta_graph_def) {
305 auto* signature_defs = meta_graph_def->mutable_signature_def();
306 SignatureDef sig_def;
307 TensorInfo input_tensor_info;
308 input_tensor_info.set_name(input_tensor_name);
310 (*sig_def.mutable_inputs())[kClassifyInputs] = input_tensor_info;
311 if (is_classification) {
312 TensorInfo scores_tensor_info;
313 scores_tensor_info.set_name(output_scores_tensor_name);
314 (*sig_def.mutable_outputs())[kClassifyOutputScores] = scores_tensor_info;
315 TensorInfo class_tensor_info;
316 class_tensor_info.set_name(kClassTensor);
317 (*sig_def.mutable_outputs())[kClassifyOutputClasses] = class_tensor_info;
318 method_name = kClassifyMethodName;
320 TensorInfo output_tensor_info;
321 output_tensor_info.set_name(output_scores_tensor_name);
322 (*sig_def.mutable_outputs())[kRegressOutputs] = output_tensor_info;
323 method_name = kRegressMethodName;
325 if (IsMethodNameCheckEnabled()) sig_def.set_method_name(method_name);
326 (*signature_defs)[signature_name] = sig_def;
// Points at saved_model_bundle_->meta_graph_def (set in SetUp()).
330 tensorflow::MetaGraphDef* meta_graph_def_;
// Non-owning; the session itself is owned by a SavedModelBundle.
331 FakeSession* fake_session_;
332 std::unique_ptr<SavedModelBundle> saved_model_bundle_;
335 std::unique_ptr<ClassifierInterface> classifier_;
338 ClassificationRequest request_;
339 ClassificationResult result_;
// Classifies a two-example ExampleList through classifier_ and through
// RunClassify(), comparing each result against the expected proto.
342 TEST_P(ClassifierTest, ExampleList) {
343 TF_ASSERT_OK(Create());
345 request_.mutable_input()->mutable_example_list()->mutable_examples();
346 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
347 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
348 TF_ASSERT_OK(classifier_->Classify(request_, &result_));
349 EXPECT_THAT(result_, EqualsProto(
" classifications { "
359 " classifications { "
// Same request through the RunClassify() entry point.
370 ClassificationResponse response;
371 TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
372 {}, fake_session_, request_, &response));
373 EXPECT_THAT(response.result(), EqualsProto(
" classifications { "
383 " classifications { "
// ExampleListWithContext: the context carries the class/score features and
// the two examples are empty. Checked through classifier_ and RunClassify().
395 TEST_P(ClassifierTest, ExampleListWithContext) {
396 TF_ASSERT_OK(Create());
397 auto* list_and_context =
398 request_.mutable_input()->mutable_example_list_with_context();
400 *list_and_context->mutable_context() = example({{
"dos", 2}, {
"uno", 1}});
402 list_and_context->add_examples();
403 list_and_context->add_examples();
404 TF_ASSERT_OK(classifier_->Classify(request_, &result_));
405 EXPECT_THAT(result_, EqualsProto(
" classifications { "
415 " classifications { "
// Same request through the RunClassify() entry point.
426 ClassificationResponse response;
427 TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
428 {}, fake_session_, request_, &response));
429 EXPECT_THAT(response.result(), EqualsProto(
" classifications { "
439 " classifications { "
// The context sets class/score features. The first example is empty; the
// second sets the same feature keys with different values.
451 TEST_P(ClassifierTest, ExampleListWithContext_DuplicateFeatures) {
452 TF_ASSERT_OK(Create());
453 auto* list_and_context =
454 request_.mutable_input()->mutable_example_list_with_context();
456 *list_and_context->mutable_context() = example({{
"uno", 1}, {
"dos", 2}});
458 list_and_context->add_examples();
461 *list_and_context->add_examples() = example({{
"tres", 3}, {
"cuatro", 4}});
462 TF_ASSERT_OK(classifier_->Classify(request_, &result_));
463 EXPECT_THAT(result_, EqualsProto(
" classifications { "
473 " classifications { "
// Same request through the RunClassify() entry point.
484 ClassificationResponse response;
485 TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
486 {}, fake_session_, request_, &response));
487 EXPECT_THAT(response.result(), EqualsProto(
" classifications { "
497 " classifications { "
// Removes the scores output from the default signature, so only classes are
// produced; checked through classifier_ and RunClassify().
509 TEST_P(ClassifierTest, ClassesOnly) {
510 auto* signature_defs = meta_graph_def_->mutable_signature_def();
511 (*signature_defs)[kDefaultServingSignatureDefKey].mutable_outputs()->erase(
512 kClassifyOutputScores);
513 TF_ASSERT_OK(Create());
515 request_.mutable_input()->mutable_example_list()->mutable_examples();
516 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
517 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
518 TF_ASSERT_OK(classifier_->Classify(request_, &result_));
519 EXPECT_THAT(result_, EqualsProto(
" classifications { "
527 " classifications { "
// Same request through the RunClassify() entry point.
536 ClassificationResponse response;
537 TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
538 {}, fake_session_, request_, &response));
539 EXPECT_THAT(response.result(), EqualsProto(
" classifications { "
547 " classifications { "
// Removes the classes output from the default signature, so only scores are
// produced; checked through classifier_ and RunClassify().
557 TEST_P(ClassifierTest, ScoresOnly) {
558 auto* signature_defs = meta_graph_def_->mutable_signature_def();
559 (*signature_defs)[kDefaultServingSignatureDefKey].mutable_outputs()->erase(
560 kClassifyOutputClasses);
562 TF_ASSERT_OK(Create());
564 request_.mutable_input()->mutable_example_list()->mutable_examples();
565 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
566 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
567 TF_ASSERT_OK(classifier_->Classify(request_, &result_));
568 EXPECT_THAT(result_, EqualsProto(
" classifications { "
576 " classifications { "
// Same request through the RunClassify() entry point.
585 ClassificationResponse response;
586 TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
587 {}, fake_session_, request_, &response));
588 EXPECT_THAT(response.result(), EqualsProto(
" classifications { "
596 " classifications { "
606 TEST_P(ClassifierTest, ZeroScoresArePresent) {
607 auto* signature_defs = meta_graph_def_->mutable_signature_def();
608 (*signature_defs)[kDefaultServingSignatureDefKey].mutable_outputs()->erase(
609 kClassifyOutputClasses);
610 TF_ASSERT_OK(Create());
612 request_.mutable_input()->mutable_example_list()->mutable_examples();
613 *examples->Add() = example({{
"minus", -1}, {
"zero", 0}, {
"one", 1}});
614 const std::vector<double> expected_outputs = {-1, 0, 1};
616 TF_ASSERT_OK(classifier_->Classify(request_, &result_));
618 ASSERT_EQ(result_.classifications_size(), 1);
619 auto& classification = result_.classifications(0);
620 ASSERT_EQ(classification.classes_size(), 3);
622 for (
int i = 0; i < 3; ++i) {
623 EXPECT_NEAR(classification.classes(i).score(), expected_outputs[i], 1e-7);
626 ClassificationResponse response;
627 TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
628 {}, fake_session_, request_, &response));
630 ASSERT_EQ(response.result().classifications_size(), 1);
631 auto& classification_resp = result_.classifications(0);
632 ASSERT_EQ(classification_resp.classes_size(), 3);
634 for (
int i = 0; i < 3; ++i) {
635 EXPECT_NEAR(classification_resp.classes(i).score(), expected_outputs[i],
// Selects kOutputPlusOneSignature, whose scores output is
// kOutputPlusOneClassTensor. FakeSession adds 1 to every score when that
// tensor is requested.
640 TEST_P(ClassifierTest, ValidNamedSignature) {
641 TF_ASSERT_OK(Create());
642 request_.mutable_model_spec()->set_signature_name(kOutputPlusOneSignature);
644 request_.mutable_input()->mutable_example_list()->mutable_examples();
645 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
646 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
647 TF_ASSERT_OK(classifier_->Classify(request_, &result_));
649 EXPECT_THAT(result_, EqualsProto(
" classifications { "
659 " classifications { "
// Same request through the RunClassify() entry point.
670 ClassificationResponse response;
671 TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
672 {}, fake_session_, request_, &response));
673 EXPECT_THAT(response.result(), EqualsProto(
" classifications { "
683 " classifications { "
// Requests kInvalidNamedSignature, which the fixture registers as a
// regression (not classification) signature. Both classifier_ and
// RunClassify() must fail with kInvalidArgument.
695 TEST_P(ClassifierTest, InvalidNamedSignature) {
696 TF_ASSERT_OK(Create());
697 request_.mutable_model_spec()->set_signature_name(kInvalidNamedSignature);
699 request_.mutable_input()->mutable_example_list()->mutable_examples();
700 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
701 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
702 Status status = classifier_->Classify(request_, &result_);
704 ASSERT_FALSE(status.ok());
705 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
// Same request through the RunClassify() entry point.
709 ClassificationResponse response;
710 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
711 fake_session_, request_, &response);
712 ASSERT_FALSE(status.ok());
713 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
// Uses kImproperlySizedScoresSignature, whose scores tensor FakeSession
// returns with rank 3. Both classifier_ and RunClassify() must fail with
// kInvalidArgument.
718 TEST_P(ClassifierTest, MalformedScores) {
719 TF_ASSERT_OK(Create());
720 request_.mutable_model_spec()->set_signature_name(
721 kImproperlySizedScoresSignature);
723 request_.mutable_input()->mutable_example_list()->mutable_examples();
724 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
725 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
726 Status status = classifier_->Classify(request_, &result_);
728 ASSERT_FALSE(status.ok());
729 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
// Same request through the RunClassify() entry point.
733 ClassificationResponse response;
734 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
735 fake_session_, request_, &response);
736 ASSERT_FALSE(status.ok());
737 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
// Replaces the default signature with an empty SignatureDef. Both
// classifier_ and RunClassify() must fail with kInvalidArgument.
742 TEST_P(ClassifierTest, MissingClassificationSignature) {
743 auto* signature_defs = meta_graph_def_->mutable_signature_def();
744 SignatureDef sig_def;
745 (*signature_defs)[kDefaultServingSignatureDefKey] = sig_def;
746 TF_ASSERT_OK(Create());
748 request_.mutable_input()->mutable_example_list()->mutable_examples();
749 *examples->Add() = example({{
"dos", 2}});
751 Status status = classifier_->Classify(request_, &result_);
752 ASSERT_FALSE(status.ok());
753 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
// Same request through the RunClassify() entry point.
757 ClassificationResponse response;
758 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
759 fake_session_, request_, &response);
760 ASSERT_FALSE(status.ok());
761 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
766 TEST_P(ClassifierTest, EmptyInput) {
767 TF_ASSERT_OK(Create());
769 request_.mutable_input();
770 Status status = classifier_->Classify(request_, &result_);
771 ASSERT_FALSE(status.ok());
772 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
773 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
775 ClassificationResponse response;
776 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
777 fake_session_, request_, &response);
778 ASSERT_FALSE(status.ok());
779 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
780 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
783 TEST_P(ClassifierTest, EmptyExampleList) {
784 TF_ASSERT_OK(Create());
786 request_.mutable_input()->mutable_example_list();
787 Status status = classifier_->Classify(request_, &result_);
788 ASSERT_FALSE(status.ok());
789 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
790 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
792 ClassificationResponse response;
793 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
794 fake_session_, request_, &response);
795 ASSERT_FALSE(status.ok());
796 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
797 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
800 TEST_P(ClassifierTest, EmptyExampleListWithContext) {
801 TF_ASSERT_OK(Create());
803 *request_.mutable_input()
804 ->mutable_example_list_with_context()
805 ->mutable_context() = example({{
"dos", 2}});
806 Status status = classifier_->Classify(request_, &result_);
807 ASSERT_FALSE(status.ok());
808 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
809 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
811 ClassificationResponse response;
812 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
813 fake_session_, request_, &response);
814 ASSERT_FALSE(status.ok());
815 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
816 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
// Replaces the fake with a MockSession whose 7-argument Run() returns an
// Internal error. Both classifier_ and RunClassify() must surface that
// error's message.
819 TEST_P(ClassifierTest, RunsFails) {
820 MockSession* mock =
new MockSession;
821 saved_model_bundle_->session.reset(mock);
822 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
824 ::testing::Return(errors::Internal(
"Run totally failed")));
825 TF_ASSERT_OK(Create());
827 request_.mutable_input()->mutable_example_list()->mutable_examples();
828 *examples->Add() = example({{
"dos", 2}});
829 Status status = classifier_->Classify(request_, &result_);
830 ASSERT_FALSE(status.ok());
831 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"Run totally failed"));
// Same request through RunClassify(), passing the mock directly.
833 ClassificationResponse response;
834 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
835 mock, request_, &response);
836 ASSERT_FALSE(status.ok());
837 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"Run totally failed"));
// The mock returns classes of shape {1, 2} (and scores {2, 2}) for a
// two-example request. Both paths must report a batch-size error.
840 TEST_P(ClassifierTest, ClassesIncorrectTensorBatchSize) {
841 MockSession* mock =
new MockSession;
842 saved_model_bundle_->session.reset(mock);
844 Tensor classes(DT_STRING, TensorShape({1, 2}));
845 Tensor scores(DT_FLOAT, TensorShape({2, 2}));
846 std::vector<Tensor> outputs = {classes, scores};
847 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
848 .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
849 ::testing::Return(absl::OkStatus())));
850 TF_ASSERT_OK(Create());
852 request_.mutable_input()->mutable_example_list()->mutable_examples();
853 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
854 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
856 Status status = classifier_->Classify(request_, &result_);
857 ASSERT_FALSE(status.ok());
858 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"batch size"));
// Same request through RunClassify(), passing the mock directly.
860 ClassificationResponse response;
861 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
862 mock, request_, &response);
863 ASSERT_FALSE(status.ok());
864 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"batch size"));
// The mock returns a DT_FLOAT classes tensor. Both paths must reject it
// because classes are required to be DT_STRING.
867 TEST_P(ClassifierTest, ClassesIncorrectTensorType) {
868 MockSession* mock =
new MockSession;
869 saved_model_bundle_->session.reset(mock);
872 Tensor classes(DT_FLOAT, TensorShape({2, 2}));
873 Tensor scores(DT_FLOAT, TensorShape({2, 2}));
874 std::vector<Tensor> outputs = {classes, scores};
875 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
876 .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
877 ::testing::Return(absl::OkStatus())));
878 TF_ASSERT_OK(Create());
880 request_.mutable_input()->mutable_example_list()->mutable_examples();
881 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
882 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
884 Status status = classifier_->Classify(request_, &result_);
885 ASSERT_FALSE(status.ok());
886 EXPECT_THAT(status.ToString(),
887 ::testing::HasSubstr(
"Expected classes Tensor of DT_STRING"));
// Same request through RunClassify(), passing the mock directly.
888 ClassificationResponse response;
889 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
890 mock, request_, &response);
891 ASSERT_FALSE(status.ok());
892 EXPECT_THAT(status.ToString(),
893 ::testing::HasSubstr(
"Expected classes Tensor of DT_STRING"));
// The mock returns scores of shape {1, 2} (and classes {2, 2}) for a
// two-example request. Both paths must report a batch-size error.
896 TEST_P(ClassifierTest, ScoresIncorrectTensorBatchSize) {
897 MockSession* mock =
new MockSession;
898 saved_model_bundle_->session.reset(mock);
899 Tensor classes(DT_STRING, TensorShape({2, 2}));
901 Tensor scores(DT_FLOAT, TensorShape({1, 2}));
902 std::vector<Tensor> outputs = {classes, scores};
903 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
904 .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
905 ::testing::Return(absl::OkStatus())));
906 TF_ASSERT_OK(Create());
908 request_.mutable_input()->mutable_example_list()->mutable_examples();
909 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
910 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
912 Status status = classifier_->Classify(request_, &result_);
913 ASSERT_FALSE(status.ok());
914 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"batch size"));
// Same request through RunClassify(), passing the mock directly.
916 ClassificationResponse response;
917 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
918 mock, request_, &response);
919 ASSERT_FALSE(status.ok());
920 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"batch size"));
// The mock returns a DT_STRING scores tensor. Both paths must reject it
// because scores are required to be DT_FLOAT.
923 TEST_P(ClassifierTest, ScoresIncorrectTensorType) {
924 MockSession* mock =
new MockSession;
925 saved_model_bundle_->session.reset(mock);
926 Tensor classes(DT_STRING, TensorShape({2, 2}));
928 Tensor scores(DT_STRING, TensorShape({2, 2}));
929 std::vector<Tensor> outputs = {classes, scores};
930 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
931 .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
932 ::testing::Return(absl::OkStatus())));
933 TF_ASSERT_OK(Create());
935 request_.mutable_input()->mutable_example_list()->mutable_examples();
936 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
937 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
939 Status status = classifier_->Classify(request_, &result_);
940 ASSERT_FALSE(status.ok());
941 EXPECT_THAT(status.ToString(),
942 ::testing::HasSubstr(
"Expected scores Tensor of DT_FLOAT"));
// Same request through RunClassify(), passing the mock directly.
944 ClassificationResponse response;
945 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
946 mock, request_, &response);
947 ASSERT_FALSE(status.ok());
948 EXPECT_THAT(status.ToString(),
949 ::testing::HasSubstr(
"Expected scores Tensor of DT_FLOAT"));
// The mock returns classes {2, 2} and scores {2, 3}. Both paths must report
// that the two tensors disagree in dim_size(1).
952 TEST_P(ClassifierTest, MismatchedNumberOfTensorClasses) {
953 MockSession* mock =
new MockSession;
954 saved_model_bundle_->session.reset(mock);
955 Tensor classes(DT_STRING, TensorShape({2, 2}));
957 Tensor scores(DT_FLOAT, TensorShape({2, 3}));
958 std::vector<Tensor> outputs = {classes, scores};
959 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
960 .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
961 ::testing::Return(absl::OkStatus())));
962 TF_ASSERT_OK(Create());
964 request_.mutable_input()->mutable_example_list()->mutable_examples();
965 *examples->Add() = example({{
"dos", 2}, {
"uno", 1}});
966 *examples->Add() = example({{
"cuatro", 4}, {
"tres", 3}});
968 Status status = classifier_->Classify(request_, &result_);
969 ASSERT_FALSE(status.ok());
972 ::testing::HasSubstr(
973 "Tensors class and score should match in dim_size(1). Got 2 vs. 3"));
// Same request through RunClassify(), passing the mock directly.
975 ClassificationResponse response;
976 status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
977 mock, request_, &response);
978 ASSERT_FALSE(status.ok());
979 EXPECT_THAT(status.ToString(),
980 ::testing::HasSubstr(
"Tensors class and score should match in "
981 "dim_size(1). Got 2 vs. 3"));
// Exercises RunClassify() directly (no Create()) with different method names
// on the default signature.
984 TEST_P(ClassifierTest, MethodNameCheck) {
985 ClassificationResponse response;
986 *request_.mutable_input()->mutable_example_list()->mutable_examples()->Add() =
987 example({{
"dos", 2}, {
"uno", 1}});
988 auto* signature_defs = meta_graph_def_->mutable_signature_def();
// Correct classify method name: RunClassify must succeed.
991 (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
992 kClassifyMethodName);
993 TF_EXPECT_OK(RunClassify(GetRunOptions(), *meta_graph_def_, {}, fake_session_,
994 request_, &response));
// Unsupported method name: the expected outcome is compared against
// !IsMethodNameCheckEnabled().
997 (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
998 "not/supported/method");
999 EXPECT_EQ(RunClassify(GetRunOptions(), *meta_graph_def_, {}, fake_session_,
1000 request_, &response)
1002 !IsMethodNameCheckEnabled());
// No method name at all: same expectation as the unsupported case.
1005 (*signature_defs)[kDefaultServingSignatureDefKey].clear_method_name();
1006 EXPECT_EQ(RunClassify(GetRunOptions(), *meta_graph_def_, {}, fake_session_,
1007 request_, &response)
1009 !IsMethodNameCheckEnabled());
1012 INSTANTIATE_TEST_SUITE_P(Classifier, ClassifierTest, ::testing::Bool());