TensorFlow Serving C++ API Documentation
classifier.cc
/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_serving/servables/tensorflow/classifier.h"

#include <stddef.h>

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/apis/classifier.h"
#include "tensorflow_serving/apis/input.pb.h"
#include "tensorflow_serving/apis/model.pb.h"
#include "tensorflow_serving/servables/tensorflow/util.h"

namespace tensorflow {
namespace serving {
namespace {

// Implementation of the ClassifierInterface using SavedModel.
class SavedModelTensorFlowClassifier : public ClassifierInterface {
 public:
  explicit SavedModelTensorFlowClassifier(
      const RunOptions& run_options, Session* session,
      const SignatureDef* const signature,
      const thread::ThreadPoolOptions& thread_pool_options =
          thread::ThreadPoolOptions())
      : run_options_(run_options),
        session_(session),
        signature_(signature),
        thread_pool_options_(thread_pool_options) {}

  ~SavedModelTensorFlowClassifier() override = default;

  Status Classify(const ClassificationRequest& request,
                  ClassificationResult* result) override {
    TRACELITERAL("TensorFlowClassifier::Classify");

    string input_tensor_name;
    std::vector<string> output_tensor_names;
    TF_RETURN_IF_ERROR(PreProcessClassification(
        *signature_, &input_tensor_name, &output_tensor_names));

    std::vector<Tensor> outputs;
    int num_examples;
    int64_t runtime_latency;
    TF_RETURN_IF_ERROR(PerformOneShotTensorComputation(
        run_options_, request.input(), input_tensor_name, output_tensor_names,
        session_, &outputs, &num_examples, thread_pool_options_,
        &runtime_latency));
    RecordRuntimeLatency(request.model_spec().name(), /*api=*/"Classify",
                         /*runtime=*/"TF1", runtime_latency);

    TRACELITERAL("ConvertToClassificationResult");
    return PostProcessClassificationResult(
        *signature_, num_examples, output_tensor_names, outputs, result);
  }

 private:
  const RunOptions run_options_;
  Session* const session_;
  const SignatureDef* const signature_;
  const thread::ThreadPoolOptions thread_pool_options_;

  TF_DISALLOW_COPY_AND_ASSIGN(SavedModelTensorFlowClassifier);
};

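// Implementation of the ClassifierInterface that owns a SavedModelBundle.
// For each request it resolves the classification SignatureDef and delegates
// to a SavedModelTensorFlowClassifier built on the bundle's session.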
class SavedModelClassifier : public ClassifierInterface {
 public:
  SavedModelClassifier(const RunOptions& run_options,
                       std::unique_ptr<SavedModelBundle> bundle)
      : run_options_(run_options), bundle_(std::move(bundle)) {}

  ~SavedModelClassifier() override = default;

  Status Classify(const ClassificationRequest& request,
                  ClassificationResult* result) override {
    // Get the default signature of the graph. Expected to be a
    // classification signature.
    // TODO(b/26220896): Move TensorFlowClassifier creation to construction
    // time.
    SignatureDef signature;
    TF_RETURN_IF_ERROR(GetClassificationSignatureDef(
        request.model_spec(), bundle_->meta_graph_def, &signature));
    SavedModelTensorFlowClassifier classifier(
        run_options_, bundle_->session.get(), &signature);
    return classifier.Classify(request, result);
  }

 private:
  const RunOptions run_options_;
  std::unique_ptr<SavedModelBundle> bundle_;

  TF_DISALLOW_COPY_AND_ASSIGN(SavedModelClassifier);
};

}  // namespace

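// Creates a ClassifierInterface that takes ownership of the given
// SavedModelBundle and resolves the classification signature per request.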
Status CreateClassifierFromSavedModelBundle(
    const RunOptions& run_options, std::unique_ptr<SavedModelBundle> bundle,
    std::unique_ptr<ClassifierInterface>* service) {
  service->reset(new SavedModelClassifier(run_options, std::move(bundle)));
  return absl::OkStatus();
}

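// Creates a "flyweight" classifier bound to an existing Session and
// SignatureDef, using default thread pool options. Does not take ownership of
// `session` or `signature`; both must outlive the returned classifier.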
Status CreateFlyweightTensorFlowClassifier(
    const RunOptions& run_options, Session* session,
    const SignatureDef* signature,
    std::unique_ptr<ClassifierInterface>* service) {
  return CreateFlyweightTensorFlowClassifier(
      run_options, session, signature, thread::ThreadPoolOptions(), service);
}

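// As above, but runs the session with the given thread pool options.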
Status CreateFlyweightTensorFlowClassifier(
    const RunOptions& run_options, Session* session,
    const SignatureDef* signature,
    const thread::ThreadPoolOptions& thread_pool_options,
    std::unique_ptr<ClassifierInterface>* service) {
  service->reset(new SavedModelTensorFlowClassifier(
      run_options, session, signature, thread_pool_options));
  return absl::OkStatus();
}

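// Looks up the SignatureDef named in `model_spec` (falling back to the
// default serving signature key) in `meta_graph_def`, verifies it is a
// classification signature, and copies it into `signature`.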
Status GetClassificationSignatureDef(const ModelSpec& model_spec,
                                     const MetaGraphDef& meta_graph_def,
                                     SignatureDef* signature) {
  const string signature_name = model_spec.signature_name().empty()
                                    ? kDefaultServingSignatureDefKey
                                    : model_spec.signature_name();
  auto iter = meta_graph_def.signature_def().find(signature_name);
  if (iter == meta_graph_def.signature_def().end()) {
    return errors::InvalidArgument(strings::StrCat(
        "No signature was found with the name: ", signature_name));
  }
  if (GetSignatureMethodNameCheckFeature()) {
    if (iter->second.method_name() != kClassifyMethodName) {
      return errors::InvalidArgument(strings::StrCat(
          "Expected classification signature method_name to be ",
          kClassifyMethodName, ". Was: ", iter->second.method_name()));
    }
  } else {
    TF_RETURN_IF_ERROR(
        PreProcessClassification(iter->second, nullptr, nullptr));
  }
  *signature = iter->second;
  return absl::OkStatus();
}

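// Validates that `signature` is a classification signature with exactly one
// input and one or two outputs (classes and/or scores). On success, fills in
// the input tensor name and the output tensor names, if requested.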
Status PreProcessClassification(const SignatureDef& signature,
                                string* input_tensor_name,
                                std::vector<string>* output_tensor_names) {
  if (GetSignatureMethodNameCheckFeature() &&
      signature.method_name() != kClassifyMethodName) {
    return errors::InvalidArgument(strings::StrCat(
        "Expected classification signature method_name to be ",
        kClassifyMethodName, ". Was: ", signature.method_name()));
  }
  if (signature.inputs().size() != 1) {
    return errors::InvalidArgument(
        strings::StrCat("Expected one input Tensor."));
  }
  if (signature.outputs().size() != 1 && signature.outputs().size() != 2) {
    return errors::InvalidArgument(
        strings::StrCat("Expected one or two output Tensors, found ",
                        signature.outputs().size()));
  }

  auto input_iter = signature.inputs().find(kClassifyInputs);
  if (input_iter == signature.inputs().end()) {
    return errors::InvalidArgument(
        "No classification inputs found in SignatureDef: ",
        signature.DebugString());
  }
  if (input_tensor_name != nullptr) {
    *input_tensor_name = input_iter->second.name();
  }

  auto classes_iter = signature.outputs().find(kClassifyOutputClasses);
  auto scores_iter = signature.outputs().find(kClassifyOutputScores);
  if (classes_iter == signature.outputs().end() &&
      scores_iter == signature.outputs().end()) {
    return errors::InvalidArgument(strings::StrCat(
        "Expected classification signature outputs to contain at least one of ",
        "\"", kClassifyOutputClasses, "\" or \"", kClassifyOutputScores,
        "\". Signature was: ", signature.DebugString()));
  }
  if (output_tensor_names != nullptr) {
    if (classes_iter != signature.outputs().end()) {
      output_tensor_names->push_back(classes_iter->second.name());
    }
    if (scores_iter != signature.outputs().end()) {
      output_tensor_names->push_back(scores_iter->second.name());
    }
  }
  return absl::OkStatus();
}

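// Validates the classes and scores output tensors (shape, dtype, batch size)
// and converts them into the ClassificationResult proto: one Classifications
// entry per example, with one Class per column.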
Status PostProcessClassificationResult(
    const SignatureDef& signature, int num_examples,
    const std::vector<string>& output_tensor_names,
    const std::vector<Tensor>& output_tensors, ClassificationResult* result) {
  if (output_tensors.size() != output_tensor_names.size()) {
    return errors::InvalidArgument(
        strings::StrCat("Expected ", output_tensor_names.size(),
                        " output tensor(s). Got: ", output_tensors.size()));
  }

  auto classes_iter = signature.outputs().find(kClassifyOutputClasses);
  string classes_tensor_name;
  if (classes_iter != signature.outputs().end()) {
    classes_tensor_name = classes_iter->second.name();
  }
  auto scores_iter = signature.outputs().find(kClassifyOutputScores);
  string scores_tensor_name;
  if (scores_iter != signature.outputs().end()) {
    scores_tensor_name = scores_iter->second.name();
  }

  const Tensor* classes = nullptr;
  const Tensor* scores = nullptr;
  for (int i = 0; i < output_tensors.size(); ++i) {
    if (output_tensor_names[i] == classes_tensor_name) {
      classes = &output_tensors[i];
    } else if (output_tensor_names[i] == scores_tensor_name) {
      scores = &output_tensors[i];
    }
  }

  // Validate classes output Tensor.
  if (classes) {
    if (classes->dims() != 2) {
      return errors::InvalidArgument(
          "Expected Tensor shape: [batch_size num_classes] but got ",
          classes->shape().DebugString());
    }
    if (classes->dtype() != DT_STRING) {
      return errors::InvalidArgument(
          "Expected classes Tensor of DT_STRING. Got: ",
          DataType_Name(classes->dtype()));
    }
    if (classes->dim_size(0) != num_examples) {
      return errors::InvalidArgument("Expected classes output batch size of ",
                                     num_examples,
                                     ". Got: ", classes->dim_size(0));
    }
  }
  // Validate scores output Tensor.
  if (scores) {
    if (scores->dims() != 2) {
      return errors::InvalidArgument(
          "Expected Tensor shape: [batch_size num_classes] but got ",
          scores->shape().DebugString());
    }
    if (scores->dtype() != DT_FLOAT) {
      return errors::InvalidArgument(
          "Expected scores Tensor of DT_FLOAT. Got: ",
          DataType_Name(scores->dtype()));
    }
    if (scores->dim_size(0) != num_examples) {
      return errors::InvalidArgument("Expected scores output batch size of ",
                                     num_examples,
                                     ". Got: ", scores->dim_size(0));
    }
  }
  // Extract the number of classes from either the class or score output
  // Tensor.
  int num_classes = 0;
  if (classes && scores) {
    // If we have both Tensors they should agree in the second dimension.
    if (classes->dim_size(1) != scores->dim_size(1)) {
      return errors::InvalidArgument(
          "Tensors class and score should match in dim_size(1). Got ",
          classes->dim_size(1), " vs. ", scores->dim_size(1));
    }
    num_classes = classes->dim_size(1);
  } else if (classes) {
    num_classes = classes->dim_size(1);
  } else if (scores) {
    num_classes = scores->dim_size(1);
  }

  // Convert the output to ClassificationResult format.
  for (int i = 0; i < num_examples; ++i) {
    serving::Classifications* classifications = result->add_classifications();
    for (int c = 0; c < num_classes; ++c) {
      serving::Class* cl = classifications->add_classes();
      if (classes) {
        const tstring& class_tstr = (classes->matrix<tstring>())(i, c);
        cl->set_label(class_tstr.data(), class_tstr.size());
      }
      if (scores) {
        cl->set_score((scores->matrix<float>())(i, c));
      }
    }
  }
  return absl::OkStatus();
}

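// Resolves the classification signature for the request, runs it on the given
// Session via a flyweight classifier, and fills in the response's model spec
// and classification result.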
Status RunClassify(const RunOptions& run_options,
                   const MetaGraphDef& meta_graph_def,
                   const absl::optional<int64_t>& servable_version,
                   Session* session, const ClassificationRequest& request,
                   ClassificationResponse* response,
                   const thread::ThreadPoolOptions& thread_pool_options) {
  SignatureDef signature;
  TF_RETURN_IF_ERROR(GetClassificationSignatureDef(request.model_spec(),
                                                   meta_graph_def, &signature));

  std::unique_ptr<ClassifierInterface> classifier_interface;
  TF_RETURN_IF_ERROR(CreateFlyweightTensorFlowClassifier(
      run_options, session, &signature, thread_pool_options,
      &classifier_interface));

  MakeModelSpec(request.model_spec().name(),
                request.model_spec().signature_name(), servable_version,
                response->mutable_model_spec());

  // Run classification.
  return classifier_interface->Classify(request, response->mutable_result());
}

}  // namespace serving
}  // namespace tensorflow
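
Example usage (sketch)

For reference, below is a minimal sketch of how a caller might exercise this API directly, outside of the Model Server frontend. It assumes a classification SavedModel exported at an illustrative path ("/tmp/model/1") tagged with kSavedModelTagServe, and an illustrative model name ("example_model"); the empty tf.Example would in practice carry the features the model's classification signature expects. Error handling is reduced to TF_CHECK_OK, and the program must be linked against TensorFlow and TensorFlow Serving.

example_usage.cc

#include <memory>
#include <utility>

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/servables/tensorflow/classifier.h"

namespace tf = tensorflow;
namespace tfs = tensorflow::serving;

int main() {
  // Load the SavedModel; the export path and tag set are illustrative.
  auto bundle = std::make_unique<tf::SavedModelBundle>();
  TF_CHECK_OK(tf::LoadSavedModel(tf::SessionOptions(), tf::RunOptions(),
                                 "/tmp/model/1", {tf::kSavedModelTagServe},
                                 bundle.get()));

  // Wrap the bundle in a ClassifierInterface; the classifier takes ownership.
  std::unique_ptr<tfs::ClassifierInterface> classifier;
  TF_CHECK_OK(tfs::CreateClassifierFromSavedModelBundle(
      tf::RunOptions(), std::move(bundle), &classifier));

  // Build a request with a single, empty tf.Example. A real request would
  // populate the features expected by the model's classification signature.
  tfs::ClassificationRequest request;
  request.mutable_model_spec()->set_name("example_model");
  request.mutable_input()->mutable_example_list()->add_examples();

  // Run classification and print the resulting classes/scores.
  tfs::ClassificationResult result;
  TF_CHECK_OK(classifier->Classify(request, &result));
  LOG(INFO) << "Result: " << result.DebugString();
  return 0;
}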