TensorFlow Serving C++ API Documentation
classifier_test.cc
/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_serving/servables/tensorflow/classifier.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/map.h"
#include "absl/types/optional.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/apis/input.pb.h"
#include "tensorflow_serving/apis/model.pb.h"
#include "tensorflow_serving/core/test_util/mock_session.h"
#include "tensorflow_serving/servables/tensorflow/util.h"
#include "tensorflow_serving/test_util/test_util.h"

namespace tensorflow {
namespace serving {
namespace {

using test_util::EqualsProto;
using test_util::MockSession;
using ::testing::_;

const char kInputTensor[] = "input:0";
const char kClassTensor[] = "output:0";
const char kOutputPlusOneClassTensor[] = "outputPlusOne:0";
const char kClassFeature[] = "class";
const char kScoreTensor[] = "score:0";
const char kScoreFeature[] = "score";
const char kImproperlySizedScoresTensor[] = "ImproperlySizedScores:0";

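// Names of the additional signatures that the test fixture installs beside the
// default serving signature.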
const char kOutputPlusOneSignature[] = "output_plus_one";
const char kInvalidNamedSignature[] = "invalid_regression_signature";
const char kImproperlySizedScoresSignature[] = "ImproperlySizedScoresSignature";

// Fake Session used for testing TensorFlowClassifier.
// Assumes the input Tensor "input:0" has serialized tensorflow::Example values.
// Copies the "class" bytes feature from each Example to be the classification
// class for that example.
class FakeSession : public tensorflow::Session {
 public:
  explicit FakeSession(absl::optional<int64_t> expected_timeout)
      : expected_timeout_(expected_timeout) {}
  ~FakeSession() override = default;
  Status Create(const GraphDef& graph) override {
    return errors::Unimplemented("not available in fake");
  }
  Status Extend(const GraphDef& graph) override {
    return errors::Unimplemented("not available in fake");
  }

  Status Close() override {
    return errors::Unimplemented("not available in fake");
  }

  Status ListDevices(std::vector<DeviceAttributes>* response) override {
    return errors::Unimplemented("not available in fake");
  }

  Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_names,
             const std::vector<string>& target_nodes,
             std::vector<Tensor>* outputs) override {
    if (expected_timeout_) {
      LOG(FATAL) << "Run() without RunOptions not expected to be called";
    }
    RunMetadata run_metadata;
    return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
               &run_metadata);
  }

  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_names,
             const std::vector<string>& target_nodes,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
    return Run(run_options, inputs, output_names, target_nodes, outputs,
               run_metadata, thread::ThreadPoolOptions());
  }

  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_names,
             const std::vector<string>& target_nodes,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata,
             const thread::ThreadPoolOptions& thread_pool_options) override {
    if (expected_timeout_) {
      CHECK_EQ(*expected_timeout_, run_options.timeout_in_ms());
    }
    if (inputs.size() != 1 || inputs[0].first != kInputTensor) {
      return errors::Internal("Expected one input Tensor.");
    }

    const Tensor& input = inputs[0].second;
    std::vector<Example> examples;
    TF_RETURN_IF_ERROR(GetExamples(input, &examples));
    Tensor classes;
    Tensor scores;
    TF_RETURN_IF_ERROR(
        GetClassTensor(examples, output_names, &classes, &scores));
    for (const auto& output_name : output_names) {
      if (output_name == kClassTensor) {
        outputs->push_back(classes);
      } else if (output_name == kScoreTensor ||
                 output_name == kOutputPlusOneClassTensor) {
        outputs->push_back(scores);
      } else if (output_name == kImproperlySizedScoresTensor) {
        // Insert a rank 3 tensor which should be an error because scores are
        // expected to be rank 2.
        outputs->emplace_back(DT_FLOAT, TensorShape({scores.dim_size(0),
                                                     scores.dim_size(1), 10}));
      }
    }

    return absl::OkStatus();
  }

  // Parses TensorFlow Examples from a string Tensor.
  static Status GetExamples(const Tensor& input,
                            std::vector<Example>* examples) {
    examples->clear();
    const int batch_size = input.dim_size(0);
    const auto& flat_input = input.flat<tstring>();
    for (int i = 0; i < batch_size; ++i) {
      Example example;
      if (!example.ParseFromArray(flat_input(i).data(), flat_input(i).size())) {
        return errors::Internal("failed to parse example");
      }
      examples->push_back(example);
    }
    return absl::OkStatus();
  }

  // Gets the Feature from an Example with the given name. Returns empty
  // Feature if the name does not exist.
  static Feature GetFeature(const Example& example, const string& name) {
    const auto it = example.features().feature().find(name);
    if (it != example.features().feature().end()) {
      return it->second;
    }
    return Feature();
  }

  // Returns the number of individual elements in a Feature.
  static int FeatureSize(const Feature& feature) {
    if (feature.has_float_list()) {
      return feature.float_list().value_size();
    } else if (feature.has_int64_list()) {
      return feature.int64_list().value_size();
    } else if (feature.has_bytes_list()) {
      return feature.bytes_list().value_size();
    }
    return 0;
  }

  // Creates a Tensor by copying the "class" feature from each Example.
  // Requires each Example to have a bytes feature called "class" which is of
  // the same non-zero length.
  static Status GetClassTensor(const std::vector<Example>& examples,
                               const std::vector<string>& output_names,
                               Tensor* classes, Tensor* scores) {
    if (examples.empty()) {
      return errors::Internal("empty example list");
    }

    auto iter = std::find(output_names.begin(), output_names.end(),
                          kOutputPlusOneClassTensor);
    const float offset = iter == output_names.end() ? 0 : 1;

    const int batch_size = examples.size();
    const int num_classes = FeatureSize(GetFeature(examples[0], kClassFeature));
    *classes = Tensor(DT_STRING, TensorShape({batch_size, num_classes}));
    *scores = Tensor(DT_FLOAT, TensorShape({batch_size, num_classes}));
    auto classes_matrix = classes->matrix<tstring>();
    auto scores_matrix = scores->matrix<float>();

    for (int i = 0; i < batch_size; ++i) {
      const Feature classes_feature = GetFeature(examples[i], kClassFeature);
      if (FeatureSize(classes_feature) != num_classes) {
        return errors::Internal("incorrect number of classes in feature: ",
                                classes_feature.DebugString());
      }
      const Feature scores_feature = GetFeature(examples[i], kScoreFeature);
      if (FeatureSize(scores_feature) != num_classes) {
        return errors::Internal("incorrect number of scores in feature: ",
                                scores_feature.DebugString());
      }
      for (int c = 0; c < num_classes; ++c) {
        classes_matrix(i, c) = classes_feature.bytes_list().value(c);
        scores_matrix(i, c) = scores_feature.float_list().value(c) + offset;
      }
    }
    return absl::OkStatus();
  }

 private:
  const absl::optional<int64_t> expected_timeout_;
};

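// Test fixture parameterized on whether signature method-name checking is
// enabled. SetUp() builds a SavedModelBundle backed by the FakeSession above,
// with a default classification signature plus the named signatures exercised
// by the tests below.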
class ClassifierTest : public ::testing::TestWithParam<bool> {
 public:
  void SetUp() override {
    SetSignatureMethodNameCheckFeature(IsMethodNameCheckEnabled());
    saved_model_bundle_.reset(new SavedModelBundle);
    meta_graph_def_ = &saved_model_bundle_->meta_graph_def;
    absl::optional<int64_t> expected_timeout = GetRunOptions().timeout_in_ms();
    fake_session_ = new FakeSession(expected_timeout);
    saved_model_bundle_->session.reset(fake_session_);

    auto* signature_defs = meta_graph_def_->mutable_signature_def();
    SignatureDef sig_def;
    TensorInfo input_tensor_info;
    input_tensor_info.set_name(kInputTensor);
    (*sig_def.mutable_inputs())[kClassifyInputs] = input_tensor_info;
    TensorInfo class_tensor_info;
    class_tensor_info.set_name(kClassTensor);
    (*sig_def.mutable_outputs())[kClassifyOutputClasses] = class_tensor_info;
    TensorInfo scores_tensor_info;
    scores_tensor_info.set_name(kScoreTensor);
    (*sig_def.mutable_outputs())[kClassifyOutputScores] = scores_tensor_info;
    if (IsMethodNameCheckEnabled())
      sig_def.set_method_name(kClassifyMethodName);
    (*signature_defs)[kDefaultServingSignatureDefKey] = sig_def;

    AddNamedSignatureToSavedModelBundle(
        kInputTensor, kOutputPlusOneClassTensor, kOutputPlusOneSignature,
        true /* is_classification */, meta_graph_def_);
    AddNamedSignatureToSavedModelBundle(
        kInputTensor, kOutputPlusOneClassTensor, kInvalidNamedSignature,
        false /* is_classification */, meta_graph_def_);

    // Add a named signature where the output is not valid.
    AddNamedSignatureToSavedModelBundle(
        kInputTensor, kImproperlySizedScoresTensor,
        kImproperlySizedScoresSignature, true /* is_classification */,
        meta_graph_def_);
  }

 protected:
  bool IsMethodNameCheckEnabled() { return GetParam(); }

  // Returns an Example whose "class" and "score" features are populated from
  // the given (label, score) pairs.
  Example example(const std::vector<std::pair<string, float>>& class_scores) {
    Feature classes_feature;
    Feature scores_feature;
    for (const auto& class_score : class_scores) {
      classes_feature.mutable_bytes_list()->add_value(class_score.first);
      scores_feature.mutable_float_list()->add_value(class_score.second);
    }
    Example example;
    auto* features = example.mutable_features()->mutable_feature();
    (*features)[kClassFeature] = classes_feature;
    (*features)[kScoreFeature] = scores_feature;
    return example;
  }

  Status Create() {
    std::unique_ptr<SavedModelBundle> saved_model(new SavedModelBundle);
    saved_model->meta_graph_def = saved_model_bundle_->meta_graph_def;
    saved_model->session = std::move(saved_model_bundle_->session);
    return CreateClassifierFromSavedModelBundle(
        GetRunOptions(), std::move(saved_model), &classifier_);
  }

  RunOptions GetRunOptions() const {
    RunOptions run_options;
    run_options.set_timeout_in_ms(42);
    return run_options;
  }

  // Add a named signature to the mutable meta_graph_def* parameter.
  // If is_classification is false, will add a regression signature, which is
  // invalid in classification requests.
  void AddNamedSignatureToSavedModelBundle(
      const string& input_tensor_name, const string& output_scores_tensor_name,
      const string& signature_name, const bool is_classification,
      tensorflow::MetaGraphDef* meta_graph_def) {
    auto* signature_defs = meta_graph_def->mutable_signature_def();
    SignatureDef sig_def;
    TensorInfo input_tensor_info;
    input_tensor_info.set_name(input_tensor_name);
    string method_name;
    (*sig_def.mutable_inputs())[kClassifyInputs] = input_tensor_info;
    if (is_classification) {
      TensorInfo scores_tensor_info;
      scores_tensor_info.set_name(output_scores_tensor_name);
      (*sig_def.mutable_outputs())[kClassifyOutputScores] = scores_tensor_info;
      TensorInfo class_tensor_info;
      class_tensor_info.set_name(kClassTensor);
      (*sig_def.mutable_outputs())[kClassifyOutputClasses] = class_tensor_info;
      method_name = kClassifyMethodName;
    } else {
      TensorInfo output_tensor_info;
      output_tensor_info.set_name(output_scores_tensor_name);
      (*sig_def.mutable_outputs())[kRegressOutputs] = output_tensor_info;
      method_name = kRegressMethodName;
    }
    if (IsMethodNameCheckEnabled()) sig_def.set_method_name(method_name);
    (*signature_defs)[signature_name] = sig_def;
  }

  // Variables used to create the classifier.
  tensorflow::MetaGraphDef* meta_graph_def_;
  FakeSession* fake_session_;
  std::unique_ptr<SavedModelBundle> saved_model_bundle_;

  // Classifier valid after calling Create().
  std::unique_ptr<ClassifierInterface> classifier_;

  // Convenience variables.
  ClassificationRequest request_;
  ClassificationResult result_;
};

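// Verifies that a plain ExampleList request returns each Example's "class"
// labels and "score" values as classifications, via both the
// ClassifierInterface and RunClassify().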
TEST_P(ClassifierTest, ExampleList) {
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
  TF_ASSERT_OK(classifier_->Classify(request_, &result_));
  EXPECT_THAT(result_, EqualsProto(" classifications { "
                                   " classes { "
                                   " label: 'dos' "
                                   " score: 2 "
                                   " } "
                                   " classes { "
                                   " label: 'uno' "
                                   " score: 1 "
                                   " } "
                                   " } "
                                   " classifications { "
                                   " classes { "
                                   " label: 'cuatro' "
                                   " score: 4 "
                                   " } "
                                   " classes { "
                                   " label: 'tres' "
                                   " score: 3 "
                                   " } "
                                   " } "));
  // Test RunClassify
  ClassificationResponse response;
  TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
                           {}, fake_session_, request_, &response));
  EXPECT_THAT(response.result(), EqualsProto(" classifications { "
                                             " classes { "
                                             " label: 'dos' "
                                             " score: 2 "
                                             " } "
                                             " classes { "
                                             " label: 'uno' "
                                             " score: 1 "
                                             " } "
                                             " } "
                                             " classifications { "
                                             " classes { "
                                             " label: 'cuatro' "
                                             " score: 4 "
                                             " } "
                                             " classes { "
                                             " label: 'tres' "
                                             " score: 3 "
                                             " } "
                                             " } "));
}

TEST_P(ClassifierTest, ExampleListWithContext) {
  TF_ASSERT_OK(Create());
  auto* list_and_context =
      request_.mutable_input()->mutable_example_list_with_context();
  // Context gets copied to each example.
  *list_and_context->mutable_context() = example({{"dos", 2}, {"uno", 1}});
  // Add empty examples to receive the context.
  list_and_context->add_examples();
  list_and_context->add_examples();
  TF_ASSERT_OK(classifier_->Classify(request_, &result_));
  EXPECT_THAT(result_, EqualsProto(" classifications { "
                                   " classes { "
                                   " label: 'dos' "
                                   " score: 2 "
                                   " } "
                                   " classes { "
                                   " label: 'uno' "
                                   " score: 1 "
                                   " } "
                                   " } "
                                   " classifications { "
                                   " classes { "
                                   " label: 'dos' "
                                   " score: 2 "
                                   " } "
                                   " classes { "
                                   " label: 'uno' "
                                   " score: 1 "
                                   " } "
                                   " } "));

  ClassificationResponse response;
  TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
                           {}, fake_session_, request_, &response));
  EXPECT_THAT(response.result(), EqualsProto(" classifications { "
                                             " classes { "
                                             " label: 'dos' "
                                             " score: 2 "
                                             " } "
                                             " classes { "
                                             " label: 'uno' "
                                             " score: 1 "
                                             " } "
                                             " } "
                                             " classifications { "
                                             " classes { "
                                             " label: 'dos' "
                                             " score: 2 "
                                             " } "
                                             " classes { "
                                             " label: 'uno' "
                                             " score: 1 "
                                             " } "
                                             " } "));
}

TEST_P(ClassifierTest, ExampleListWithContext_DuplicateFeatures) {
  TF_ASSERT_OK(Create());
  auto* list_and_context =
      request_.mutable_input()->mutable_example_list_with_context();
  // Context gets copied to each example.
  *list_and_context->mutable_context() = example({{"uno", 1}, {"dos", 2}});
  // Add an empty example; after the context is merged in, it should equal the
  // context.
  list_and_context->add_examples();
  // Add an example with a duplicate feature. Technically this behavior is
  // undefined, so here we are ensuring we don't crash.
  *list_and_context->add_examples() = example({{"tres", 3}, {"cuatro", 4}});
  TF_ASSERT_OK(classifier_->Classify(request_, &result_));
  EXPECT_THAT(result_, EqualsProto(" classifications { "
                                   " classes { "
                                   " label: 'uno' "
                                   " score: 1 "
                                   " } "
                                   " classes { "
                                   " label: 'dos' "
                                   " score: 2 "
                                   " } "
                                   " } "
                                   " classifications { "
                                   " classes { "
                                   " label: 'tres' "
                                   " score: 3 "
                                   " } "
                                   " classes { "
                                   " label: 'cuatro' "
                                   " score: 4 "
                                   " } "
                                   " } "));

  ClassificationResponse response;
  TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
                           {}, fake_session_, request_, &response));
  EXPECT_THAT(response.result(), EqualsProto(" classifications { "
                                             " classes { "
                                             " label: 'uno' "
                                             " score: 1 "
                                             " } "
                                             " classes { "
                                             " label: 'dos' "
                                             " score: 2 "
                                             " } "
                                             " } "
                                             " classifications { "
                                             " classes { "
                                             " label: 'tres' "
                                             " score: 3 "
                                             " } "
                                             " classes { "
                                             " label: 'cuatro' "
                                             " score: 4 "
                                             " } "
                                             " } "));
}

TEST_P(ClassifierTest, ClassesOnly) {
  auto* signature_defs = meta_graph_def_->mutable_signature_def();
  (*signature_defs)[kDefaultServingSignatureDefKey].mutable_outputs()->erase(
      kClassifyOutputScores);
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
  TF_ASSERT_OK(classifier_->Classify(request_, &result_));
  EXPECT_THAT(result_, EqualsProto(" classifications { "
                                   " classes { "
                                   " label: 'dos' "
                                   " } "
                                   " classes { "
                                   " label: 'uno' "
                                   " } "
                                   " } "
                                   " classifications { "
                                   " classes { "
                                   " label: 'cuatro' "
                                   " } "
                                   " classes { "
                                   " label: 'tres' "
                                   " } "
                                   " } "));

  ClassificationResponse response;
  TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
                           {}, fake_session_, request_, &response));
  EXPECT_THAT(response.result(), EqualsProto(" classifications { "
                                             " classes { "
                                             " label: 'dos' "
                                             " } "
                                             " classes { "
                                             " label: 'uno' "
                                             " } "
                                             " } "
                                             " classifications { "
                                             " classes { "
                                             " label: 'cuatro' "
                                             " } "
                                             " classes { "
                                             " label: 'tres' "
                                             " } "
                                             " } "));
}

TEST_P(ClassifierTest, ScoresOnly) {
  auto* signature_defs = meta_graph_def_->mutable_signature_def();
  (*signature_defs)[kDefaultServingSignatureDefKey].mutable_outputs()->erase(
      kClassifyOutputClasses);

  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
  TF_ASSERT_OK(classifier_->Classify(request_, &result_));
  EXPECT_THAT(result_, EqualsProto(" classifications { "
                                   " classes { "
                                   " score: 2 "
                                   " } "
                                   " classes { "
                                   " score: 1 "
                                   " } "
                                   " } "
                                   " classifications { "
                                   " classes { "
                                   " score: 4 "
                                   " } "
                                   " classes { "
                                   " score: 3 "
                                   " } "
                                   " } "));

  ClassificationResponse response;
  TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
                           {}, fake_session_, request_, &response));
  EXPECT_THAT(response.result(), EqualsProto(" classifications { "
                                             " classes { "
                                             " score: 2 "
                                             " } "
                                             " classes { "
                                             " score: 1 "
                                             " } "
                                             " } "
                                             " classifications { "
                                             " classes { "
                                             " score: 4 "
                                             " } "
                                             " classes { "
                                             " score: 3 "
                                             " } "
                                             " } "));
}

TEST_P(ClassifierTest, ZeroScoresArePresent) {
  auto* signature_defs = meta_graph_def_->mutable_signature_def();
  (*signature_defs)[kDefaultServingSignatureDefKey].mutable_outputs()->erase(
      kClassifyOutputClasses);
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"minus", -1}, {"zero", 0}, {"one", 1}});
  const std::vector<double> expected_outputs = {-1, 0, 1};

  TF_ASSERT_OK(classifier_->Classify(request_, &result_));
  // Parse the protos and compare the results with expected scores.
  ASSERT_EQ(result_.classifications_size(), 1);
  auto& classification = result_.classifications(0);
  ASSERT_EQ(classification.classes_size(), 3);

  for (int i = 0; i < 3; ++i) {
    EXPECT_NEAR(classification.classes(i).score(), expected_outputs[i], 1e-7);
  }

  ClassificationResponse response;
  TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
                           {}, fake_session_, request_, &response));
  // Parse the protos and compare the results with expected scores.
  ASSERT_EQ(response.result().classifications_size(), 1);
  auto& classification_resp = response.result().classifications(0);
  ASSERT_EQ(classification_resp.classes_size(), 3);

  for (int i = 0; i < 3; ++i) {
    EXPECT_NEAR(classification_resp.classes(i).score(), expected_outputs[i],
                1e-7);
  }
}

TEST_P(ClassifierTest, ValidNamedSignature) {
  TF_ASSERT_OK(Create());
  request_.mutable_model_spec()->set_signature_name(kOutputPlusOneSignature);
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
  TF_ASSERT_OK(classifier_->Classify(request_, &result_));

  EXPECT_THAT(result_, EqualsProto(" classifications { "
                                   " classes { "
                                   " label: 'dos' "
                                   " score: 3 "
                                   " } "
                                   " classes { "
                                   " label: 'uno' "
                                   " score: 2 "
                                   " } "
                                   " } "
                                   " classifications { "
                                   " classes { "
                                   " label: 'cuatro' "
                                   " score: 5 "
                                   " } "
                                   " classes { "
                                   " label: 'tres' "
                                   " score: 4 "
                                   " } "
                                   " } "));

  ClassificationResponse response;
  TF_ASSERT_OK(RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def,
                           {}, fake_session_, request_, &response));
  EXPECT_THAT(response.result(), EqualsProto(" classifications { "
                                             " classes { "
                                             " label: 'dos' "
                                             " score: 3 "
                                             " } "
                                             " classes { "
                                             " label: 'uno' "
                                             " score: 2 "
                                             " } "
                                             " } "
                                             " classifications { "
                                             " classes { "
                                             " label: 'cuatro' "
                                             " score: 5 "
                                             " } "
                                             " classes { "
                                             " label: 'tres' "
                                             " score: 4 "
                                             " } "
                                             " } "));
}

TEST_P(ClassifierTest, InvalidNamedSignature) {
  TF_ASSERT_OK(Create());
  request_.mutable_model_spec()->set_signature_name(kInvalidNamedSignature);
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
  Status status = classifier_->Classify(request_, &result_);

  ASSERT_FALSE(status.ok());
  EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()) << status;

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       fake_session_, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()) << status;
}

TEST_P(ClassifierTest, MalformedScores) {
  TF_ASSERT_OK(Create());
  request_.mutable_model_spec()->set_signature_name(
      kImproperlySizedScoresSignature);
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
  Status status = classifier_->Classify(request_, &result_);

  ASSERT_FALSE(status.ok());
  EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()) << status;

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       fake_session_, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()) << status;
}

TEST_P(ClassifierTest, MissingClassificationSignature) {
  auto* signature_defs = meta_graph_def_->mutable_signature_def();
  SignatureDef sig_def;
  (*signature_defs)[kDefaultServingSignatureDefKey] = sig_def;
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}});
  // TODO(b/26220896): This error should move to construction time.
  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()) << status;

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       fake_session_, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()) << status;
}

TEST_P(ClassifierTest, EmptyInput) {
  TF_ASSERT_OK(Create());
  // Touch input.
  request_.mutable_input();
  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       fake_session_, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));
}

TEST_P(ClassifierTest, EmptyExampleList) {
  TF_ASSERT_OK(Create());
  // Touch ExampleList.
  request_.mutable_input()->mutable_example_list();
  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       fake_session_, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));
}

TEST_P(ClassifierTest, EmptyExampleListWithContext) {
  TF_ASSERT_OK(Create());
  // Touch ExampleListWithContext, context populated but no Examples.
  *request_.mutable_input()
       ->mutable_example_list_with_context()
       ->mutable_context() = example({{"dos", 2}});
  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       fake_session_, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
  EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));
}

TEST_P(ClassifierTest, RunsFails) {
  MockSession* mock = new MockSession;
  saved_model_bundle_->session.reset(mock);
  EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
      .WillRepeatedly(
          ::testing::Return(errors::Internal("Run totally failed")));
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}});
  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(), ::testing::HasSubstr("Run totally failed"));

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       mock, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(), ::testing::HasSubstr("Run totally failed"));
}

TEST_P(ClassifierTest, ClassesIncorrectTensorBatchSize) {
  MockSession* mock = new MockSession;
  saved_model_bundle_->session.reset(mock);
  // This Tensor only has one batch item but we will have two inputs.
  Tensor classes(DT_STRING, TensorShape({1, 2}));
  Tensor scores(DT_FLOAT, TensorShape({2, 2}));
  std::vector<Tensor> outputs = {classes, scores};
  EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
      .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
                                       ::testing::Return(absl::OkStatus())));
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});

  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(), ::testing::HasSubstr("batch size"));

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       mock, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(), ::testing::HasSubstr("batch size"));
}

TEST_P(ClassifierTest, ClassesIncorrectTensorType) {
  MockSession* mock = new MockSession;
  saved_model_bundle_->session.reset(mock);

  // This Tensor is the wrong type for class.
  Tensor classes(DT_FLOAT, TensorShape({2, 2}));
  Tensor scores(DT_FLOAT, TensorShape({2, 2}));
  std::vector<Tensor> outputs = {classes, scores};
  EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
      .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
                                       ::testing::Return(absl::OkStatus())));
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});

  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(),
              ::testing::HasSubstr("Expected classes Tensor of DT_STRING"));
  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       mock, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(),
              ::testing::HasSubstr("Expected classes Tensor of DT_STRING"));
}

TEST_P(ClassifierTest, ScoresIncorrectTensorBatchSize) {
  MockSession* mock = new MockSession;
  saved_model_bundle_->session.reset(mock);
  Tensor classes(DT_STRING, TensorShape({2, 2}));
  // This Tensor only has one batch item but we will have two inputs.
  Tensor scores(DT_FLOAT, TensorShape({1, 2}));
  std::vector<Tensor> outputs = {classes, scores};
  EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
      .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
                                       ::testing::Return(absl::OkStatus())));
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});

  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(), ::testing::HasSubstr("batch size"));

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       mock, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(), ::testing::HasSubstr("batch size"));
}

TEST_P(ClassifierTest, ScoresIncorrectTensorType) {
  MockSession* mock = new MockSession;
  saved_model_bundle_->session.reset(mock);
  Tensor classes(DT_STRING, TensorShape({2, 2}));
  // This Tensor is the wrong type for scores.
  Tensor scores(DT_STRING, TensorShape({2, 2}));
  std::vector<Tensor> outputs = {classes, scores};
  EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
      .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
                                       ::testing::Return(absl::OkStatus())));
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});

  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(),
              ::testing::HasSubstr("Expected scores Tensor of DT_FLOAT"));

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       mock, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(),
              ::testing::HasSubstr("Expected scores Tensor of DT_FLOAT"));
}

TEST_P(ClassifierTest, MismatchedNumberOfTensorClasses) {
  MockSession* mock = new MockSession;
  saved_model_bundle_->session.reset(mock);
  Tensor classes(DT_STRING, TensorShape({2, 2}));
  // Scores Tensor has three scores but classes only has two labels.
  Tensor scores(DT_FLOAT, TensorShape({2, 3}));
  std::vector<Tensor> outputs = {classes, scores};
  EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
      .WillRepeatedly(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
                                       ::testing::Return(absl::OkStatus())));
  TF_ASSERT_OK(Create());
  auto* examples =
      request_.mutable_input()->mutable_example_list()->mutable_examples();
  *examples->Add() = example({{"dos", 2}, {"uno", 1}});
  *examples->Add() = example({{"cuatro", 4}, {"tres", 3}});

  Status status = classifier_->Classify(request_, &result_);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.ToString(),
      ::testing::HasSubstr(
          "Tensors class and score should match in dim_size(1). Got 2 vs. 3"));

  ClassificationResponse response;
  status = RunClassify(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
                       mock, request_, &response);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(),
              ::testing::HasSubstr("Tensors class and score should match in "
                                   "dim_size(1). Got 2 vs. 3"));
}

TEST_P(ClassifierTest, MethodNameCheck) {
  ClassificationResponse response;
  *request_.mutable_input()->mutable_example_list()->mutable_examples()->Add() =
      example({{"dos", 2}, {"uno", 1}});
  auto* signature_defs = meta_graph_def_->mutable_signature_def();

  // Legit method name. Should always work.
  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      kClassifyMethodName);
  TF_EXPECT_OK(RunClassify(GetRunOptions(), *meta_graph_def_, {}, fake_session_,
                           request_, &response));

  // Unsupported method name will fail when method check is enabled.
  (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
      "not/supported/method");
  EXPECT_EQ(RunClassify(GetRunOptions(), *meta_graph_def_, {}, fake_session_,
                        request_, &response)
                .ok(),
            !IsMethodNameCheckEnabled());

  // Empty method name will fail when method check is enabled.
  (*signature_defs)[kDefaultServingSignatureDefKey].clear_method_name();
  EXPECT_EQ(RunClassify(GetRunOptions(), *meta_graph_def_, {}, fake_session_,
                        request_, &response)
                .ok(),
            !IsMethodNameCheckEnabled());
}

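// Run every test twice: with signature method-name checking disabled and
// enabled.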
INSTANTIATE_TEST_SUITE_P(Classifier, ClassifierTest, ::testing::Bool());

}  // namespace
}  // namespace serving
}  // namespace tensorflow