TensorFlow Serving C++ API documentation — source listing of tfrt_classifier.cc
1 /* Copyright 2017 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
17 
18 #include "tensorflow/cc/saved_model/signature_constants.h"
19 #include "tensorflow/core/example/example.pb.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/notification.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/platform/threadpool.h"
25 #include "tensorflow/core/platform/threadpool_options.h"
26 #include "tensorflow/core/platform/tracing.h"
27 #include "tensorflow/core/platform/types.h"
28 #include "tensorflow/core/tfrt/utils/tensor_util.h"
29 #include "tsl/platform/error_logging.h"
30 #include "tensorflow_serving/apis/classification.pb.h"
31 #include "tensorflow_serving/apis/classifier.h"
32 #include "tensorflow_serving/apis/input.pb.h"
33 #include "tensorflow_serving/apis/model.pb.h"
34 #include "tensorflow_serving/servables/tensorflow/util.h"
35 
36 namespace tensorflow {
37 namespace serving {
38 
39 Status PreProcessClassification(
40  const tfrt::FunctionMetadata& function_metadata) {
41  if (function_metadata.GetInputNames().size() != 1) {
42  return errors::InvalidArgument(
43  strings::StrCat("Expected one input Tensor."));
44  }
45  if (function_metadata.GetOutputNames().size() != 1 &&
46  function_metadata.GetOutputNames().size() != 2) {
47  return errors::InvalidArgument(
48  strings::StrCat("Expected one or two output Tensors, found ",
49  function_metadata.GetOutputNames().size()));
50  }
51 
52  if (function_metadata.GetInputNames()[0] != kClassifyInputs) {
53  return errors::FailedPrecondition(
54  "No classification inputs found in function's metadata, only "
55  "contains: ",
56  function_metadata.GetInputNames()[0]);
57  }
58 
59  bool find_output_classes = false;
60  bool find_output_scores = false;
61  for (const std::string& output_name : function_metadata.GetOutputNames()) {
62  if (output_name == kClassifyOutputClasses) {
63  find_output_classes = true;
64  } else if ((output_name == kClassifyOutputScores)) {
65  find_output_scores = true;
66  }
67  }
68 
69  if ((function_metadata.GetOutputNames().size() == 1 && !find_output_classes &&
70  !find_output_scores) ||
71  (function_metadata.GetOutputNames().size() == 2 &&
72  !(find_output_classes && find_output_scores))) {
73  return errors::FailedPrecondition(strings::StrCat(
74  "Expected classification function outputs to contain", "\"",
75  kClassifyOutputClasses, "\" and/or \"", kClassifyOutputScores, "\". "));
76  }
77 
78  return absl::OkStatus();
79 }
80 
81 Status PostProcessClassificationResult(
82  int num_examples, const std::vector<string>& output_names,
83  const std::vector<Tensor>& output_tensors, ClassificationResult* result) {
84  if (output_tensors.size() != output_names.size()) {
85  return errors::InvalidArgument(strings::StrCat(
86  "Unexpected output tensors size. Expected ", output_names.size(),
87  " output tensor(s). Got: ", output_tensors.size()));
88  }
89 
90  const Tensor* classes = nullptr;
91  const Tensor* scores = nullptr;
92  for (int i = 0; i < output_tensors.size(); ++i) {
93  if (output_names[i] == kClassifyOutputClasses) {
94  classes = &output_tensors[i];
95  } else if (output_names[i] == kClassifyOutputScores) {
96  scores = &output_tensors[i];
97  }
98  }
99 
100  // Validate classes output Tensor.
101  if (classes) {
102  if (classes->dims() != 2) {
103  return errors::InvalidArgument(
104  "Expected Tensor shape: [batch_size num_classes] but got ",
105  classes->shape().DebugString());
106  }
107  if (classes->dtype() != DT_STRING) {
108  return errors::InvalidArgument(
109  "Expected classes Tensor of DT_STRING. Got: ",
110  DataType_Name(classes->dtype()));
111  }
112  if (classes->dim_size(0) != num_examples) {
113  return errors::InvalidArgument("Expected classes output batch size of ",
114  num_examples,
115  ". Got: ", classes->dim_size(0));
116  }
117  }
118  // Validate scores output Tensor.
119  if (scores) {
120  if (scores->dims() != 2) {
121  return errors::InvalidArgument(
122  "Expected Tensor shape: [batch_size num_classes] but got ",
123  scores->shape().DebugString());
124  }
125  if (scores->dtype() != DT_FLOAT) {
126  return errors::InvalidArgument(
127  "Expected scores Tensor of DT_FLOAT. Got: ",
128  DataType_Name(scores->dtype()));
129  }
130  if (scores->dim_size(0) != num_examples) {
131  return errors::InvalidArgument("Expected scores output batch size of ",
132  num_examples,
133  ". Got: ", scores->dim_size(0));
134  }
135  }
136  // Extract the number of classes from either the class or score output
137  // Tensor.
138  int num_classes = 0;
139  if (classes && scores) {
140  // If we have both Tensors they should agree in the second dimmension.
141  if (classes->dim_size(1) != scores->dim_size(1)) {
142  return errors::InvalidArgument(
143  "Tensors class and score should match in dim_size(1). Got ",
144  classes->dim_size(1), " vs. ", scores->dim_size(1));
145  }
146  num_classes = classes->dim_size(1);
147  } else if (classes) {
148  num_classes = classes->dim_size(1);
149  } else if (scores) {
150  num_classes = scores->dim_size(1);
151  }
152 
153  // Convert the output to ClassificationResult format.
154  for (int i = 0; i < num_examples; ++i) {
155  serving::Classifications* classifications = result->add_classifications();
156  for (int c = 0; c < num_classes; ++c) {
157  serving::Class* cl = classifications->add_classes();
158  if (classes) {
159  const tstring& class_tstr = (classes->matrix<tstring>())(i, c);
160  cl->set_label(class_tstr.data(), class_tstr.size());
161  }
162  if (scores) {
163  cl->set_score((scores->matrix<float>())(i, c));
164  }
165  }
166  }
167  return absl::OkStatus();
168 }
169 
// Executes a classification request against a TFRT SavedModel.
//
// Resolves the signature name (defaulting to kDefaultServingSignatureDefKey
// when the request leaves it empty), validates the signature via
// PreProcessClassification, serializes request.input() into a single batch
// tensor, runs the model, records TFRT runtime latency, and converts the
// outputs into response->mutable_result() via
// PostProcessClassificationResult.
//
// Returns FailedPrecondition if the signature is not found, or any error
// propagated from validation, execution, or post-processing.
Status RunClassify(const tfrt::SavedModel::RunOptions& run_options,
                   const absl::optional<int64_t>& servable_version,
                   tfrt::SavedModel* saved_model,
                   const ClassificationRequest& request,
                   ClassificationResponse* response) {
  // An empty signature name in the request means "use the default signature".
  const string function_name = request.model_spec().signature_name().empty()
                                   ? kDefaultServingSignatureDefKey
                                   : request.model_spec().signature_name();

  const auto function_metadata =
      saved_model->GetFunctionMetadata(function_name);
  if (!function_metadata.has_value()) {
    return errors::FailedPrecondition(
        strings::StrCat("Function \"", function_name, "\" not found."));
  }

  // Echo the resolved model name / signature / version into the response.
  MakeModelSpec(request.model_spec().name(), function_name, servable_version,
                response->mutable_model_spec());

  // Pre-processing.
  TF_RETURN_IF_ERROR(PreProcessClassification(function_metadata.value()));
  Tensor input_tensor;
  TF_RETURN_IF_ERROR(
      InputToSerializedExampleTensor(request.input(), &input_tensor));
  std::vector<Tensor> input_tensors;
  // Batch size (dim 0) is captured before the tensor is moved below.
  int num_examples = input_tensor.dim_size(0);
  input_tensors.emplace_back(std::move(input_tensor));

  // Executes requests.  The timestamps bracket only the model run so the
  // recorded latency excludes pre/post-processing.
  std::vector<Tensor> output_tensors;
  const uint64_t start_microseconds = EnvTime::NowMicros();
  if (const auto status = saved_model->Run(run_options, function_name,
                                           input_tensors, &output_tensors);
      !status.ok()) {
    // Optionally mirror the failure to the TFRT error log before returning.
    if (IsTfrtErrorLoggingEnabled()) {
      tsl::error_logging::Log("TFRT", "SavedModelRun", status.message())
          .IgnoreError();
    }
    return status;
  }
  const uint64_t end_microseconds = EnvTime::NowMicros();
  RecordRuntimeLatency(request.model_spec().name(), /*api=*/"Classify",
                       /*runtime=*/"TFRT",
                       end_microseconds - start_microseconds);

  // Post-processing.
  return PostProcessClassificationResult(
      num_examples, function_metadata->GetOutputNames(), output_tensors,
      response->mutable_result());
}
220 
221 } // namespace serving
222 } // namespace tensorflow