TensorFlow Serving C++ API Documentation
multi_inference.cc
1 /* Copyright 2017 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/multi_inference.h"
17 
18 #include <set>
19 #include <vector>
20 
21 #include "tensorflow/cc/saved_model/signature_constants.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/platform/tracing.h"
24 #include "tensorflow_serving/apis/input.pb.h"
25 #include "tensorflow_serving/apis/model.pb.h"
26 #include "tensorflow_serving/servables/tensorflow/classifier.h"
27 #include "tensorflow_serving/servables/tensorflow/regressor.h"
28 #include "tensorflow_serving/servables/tensorflow/util.h"
29 
30 namespace tensorflow {
31 namespace serving {
32 
33 Status TensorFlowMultiInferenceRunner::Infer(
34  const RunOptions& run_options, const MultiInferenceRequest& request,
35  MultiInferenceResponse* response) {
36  TRACELITERAL("TensorFlowMultiInferenceRunner::Infer");
37 
38  string model_name = "";
39  std::set<string> signature_names;
40  std::set<string> input_tensor_name_set;
41  std::set<string> output_tensor_name_set;
42  for (const auto& task : request.tasks()) {
43  if (task.model_spec().name().empty()) {
44  return errors::InvalidArgument(
45  "Found ModelSpec with an empty model name.");
46  }
47  if (model_name.empty()) {
48  model_name = task.model_spec().name();
49  } else if (model_name != task.model_spec().name()) {
50  return errors::InvalidArgument(
51  "All ModelSpecs in a MultiInferenceRequest must access the same "
52  "model name.");
53  }
54 
55  const string signature_name = task.model_spec().signature_name().empty()
56  ? kDefaultServingSignatureDefKey
57  : task.model_spec().signature_name();
58 
59  if (signature_names.find(signature_name) != signature_names.end()) {
60  return errors::InvalidArgument(strings::StrCat(
61  "Duplicate evaluation of signature: ", signature_name));
62  }
63  signature_names.insert(signature_name);
64 
65  auto iter = meta_graph_def_->signature_def().find(signature_name);
66  if (iter == meta_graph_def_->signature_def().end()) {
67  return errors::InvalidArgument(strings::StrCat(
68  "Requested signature not found in model graph: ", signature_name));
69  }
70  string input_name;
71  std::vector<string> output_names;
72 
73  if (task.method_name() == kClassifyMethodName) {
74  TF_RETURN_IF_ERROR(
75  PreProcessClassification(iter->second, &input_name, &output_names));
76  } else if (task.method_name() == kRegressMethodName) {
77  TF_RETURN_IF_ERROR(
78  PreProcessRegression(iter->second, &input_name, &output_names));
79  } else {
80  return errors::Unimplemented("Unsupported signature method_name: ",
81  task.method_name());
82  }
83  input_tensor_name_set.insert(input_name);
84  for (const auto& output_tensor_name : output_names) {
85  output_tensor_name_set.insert(output_tensor_name);
86  }
87  }
88 
89  const std::vector<string> output_tensor_names(output_tensor_name_set.begin(),
90  output_tensor_name_set.end());
91 
92  std::vector<Tensor> outputs;
93  int num_examples;
94  TF_RETURN_IF_ERROR(PerformOneShotTensorComputation(
95  run_options, request.input(), input_tensor_name_set, output_tensor_names,
96  session_, &outputs, &num_examples, thread_pool_options_));
97  RecordRequestExampleCount(model_name, num_examples);
98 
99  TRACELITERAL("PostProcessResults");
100  for (const auto& task : request.tasks()) {
101  const string signature_name = task.model_spec().signature_name().empty()
102  ? kDefaultServingSignatureDefKey
103  : task.model_spec().signature_name();
104  auto iter = meta_graph_def_->signature_def().find(signature_name);
105  if (iter == meta_graph_def_->signature_def().end()) {
106  return errors::InvalidArgument(strings::StrCat(
107  "Requested signature not found in model graph: ", signature_name));
108  }
109  if (task.method_name() == kClassifyMethodName) {
110  TF_RETURN_IF_ERROR(PostProcessClassificationResult(
111  iter->second, num_examples, output_tensor_names, outputs,
112  response->add_results()->mutable_classification_result()));
113  } else if (task.method_name() == kRegressMethodName) {
114  TF_RETURN_IF_ERROR(PostProcessRegressionResult(
115  iter->second, num_examples, output_tensor_names, outputs,
116  response->add_results()->mutable_regression_result()));
117  } else {
118  return errors::InvalidArgument("Unrecognized signature method_name: ",
119  task.method_name());
120  }
121  MakeModelSpec(task.model_spec().name(), task.model_spec().signature_name(),
122  servable_version_,
123  response->mutable_results(response->results_size() - 1)
124  ->mutable_model_spec());
125  }
126  return absl::OkStatus();
127 }
128 
129 Status RunMultiInference(
130  const RunOptions& run_options, const MetaGraphDef& meta_graph_def,
131  const absl::optional<int64_t>& servable_version, Session* session,
132  const MultiInferenceRequest& request, MultiInferenceResponse* response,
133  const tensorflow::thread::ThreadPoolOptions& thread_pool_options) {
134  TRACELITERAL("RunMultiInference");
135 
136  TensorFlowMultiInferenceRunner inference_runner(
137  session, &meta_graph_def, servable_version, thread_pool_options);
138  return inference_runner.Infer(run_options, request, response);
139 }
140 
141 } // namespace serving
142 } // namespace tensorflow