TensorFlow Serving C++ API documentation — source listing of predict_util.cc (utilities for handling Predict requests).
1 /* Copyright 2018 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/predict_util.h"
17 
18 #include <map>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/substitute.h"
27 #include "tensorflow/cc/saved_model/signature_constants.h"
28 #include "tensorflow/cc/saved_model/util.h"
29 #include "tensorflow/core/framework/tensor.pb.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/protobuf/named_tensor.pb.h"
32 #include "tensorflow_serving/servables/tensorflow/util.h"
33 
34 namespace tensorflow {
35 namespace serving {
36 namespace {
37 
38 Status VerifySignature(const SignatureDef& signature) {
39  if (GetSignatureMethodNameCheckFeature() &&
40  signature.method_name() != kPredictMethodName &&
41  signature.method_name() != kClassifyMethodName &&
42  signature.method_name() != kRegressMethodName) {
43  return errors::Internal(strings::StrCat(
44  "Expected prediction signature method_name to be one of {",
45  kPredictMethodName, ", ", kClassifyMethodName, ", ", kRegressMethodName,
46  "}. Was: ", signature.method_name()));
47  }
48  return absl::OkStatus();
49 }
50 
51 Status VerifyRequestInputsSize(const SignatureDef& signature,
52  const PredictRequest& request) {
53  if (request.inputs().size() > signature.inputs().size() ||
54  (request.inputs().size() < signature.inputs().size() &&
55  signature.defaults().empty())) {
56  const std::set<string> request_inputs = GetMapKeys(request.inputs());
57  const std::set<string> signature_inputs = GetMapKeys(signature.inputs());
58  const std::set<string> sent_extra =
59  SetDifference(request_inputs, signature_inputs);
60  const std::set<string> missing =
61  SetDifference(signature_inputs, request_inputs);
62  return tensorflow::Status(
63  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
64  absl::StrCat(
65  "input size does not match signature: ", request.inputs().size(),
66  "!=", signature.inputs().size(), " len({",
67  absl::StrJoin(request_inputs, ","), "}) != len({",
68  absl::StrJoin(signature_inputs, ","), "}). Sent extra: {",
69  absl::StrJoin(sent_extra, ","), "}. Missing but required: {",
70  absl::StrJoin(missing, ","), "}."));
71  }
72  return absl::OkStatus();
73 }
74 
75 } // namespace
76 
77 namespace internal {
78 Status RunPredict(
79  const RunOptions& run_options, const MetaGraphDef& meta_graph_def,
80  const absl::optional<int64_t>& servable_version,
81  const internal::PredictResponseTensorSerializationOption option,
82  Session* session, const PredictRequest& request, PredictResponse* response,
83  const thread::ThreadPoolOptions& thread_pool_options) {
84  // Validate signatures.
85  const string signature_name = request.model_spec().signature_name().empty()
86  ? kDefaultServingSignatureDefKey
87  : request.model_spec().signature_name();
88  auto iter = meta_graph_def.signature_def().find(signature_name);
89  if (iter == meta_graph_def.signature_def().end()) {
90  return errors::FailedPrecondition(strings::StrCat(
91  "Serving signature key \"", signature_name, "\" not found."));
92  }
93  const SignatureDef& signature = iter->second;
94 
95  MakeModelSpec(request.model_spec().name(), signature_name, servable_version,
96  response->mutable_model_spec());
97 
98  std::vector<std::pair<string, Tensor>> input_tensors;
99  std::vector<string> output_tensor_names;
100  std::vector<string> output_tensor_aliases;
101  TF_RETURN_IF_ERROR(PreProcessPrediction(signature, request, &input_tensors,
102  &output_tensor_names,
103  &output_tensor_aliases));
104  std::vector<Tensor> outputs;
105  RunMetadata run_metadata;
106  const uint64_t start_microseconds = EnvTime::NowMicros();
107  TF_RETURN_IF_ERROR(session->Run(run_options, input_tensors,
108  output_tensor_names, {}, &outputs,
109  &run_metadata, thread_pool_options));
110  const uint64_t end_microseconds = EnvTime::NowMicros();
111  RecordRuntimeLatency(request.model_spec().name(), /*api=*/"Predict",
112  /*runtime=*/"TF1",
113  end_microseconds - start_microseconds);
114 
115  return PostProcessPredictionResult(output_tensor_aliases, outputs, option,
116  response);
117 }
118 
119 Status PreProcessPrediction(const SignatureDef& signature,
120  const PredictRequest& request,
121  std::vector<std::pair<string, Tensor>>* inputs,
122  std::vector<string>* output_tensor_names,
123  std::vector<string>* output_tensor_aliases) {
124  TF_RETURN_IF_ERROR(VerifySignature(signature));
125  TF_RETURN_IF_ERROR(VerifyRequestInputsSize(signature, request));
126  TF_RETURN_IF_ERROR(
127  saved_model::GetInputValues(signature, request.inputs(), *inputs));
128 
129  // Prepare run target.
130  std::set<string> seen_outputs;
131  std::vector<string> output_filter(request.output_filter().begin(),
132  request.output_filter().end());
133  for (auto& alias : output_filter) {
134  auto iter = signature.outputs().find(alias);
135  if (iter == signature.outputs().end()) {
136  return tensorflow::Status(
137  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
138  strings::StrCat("output tensor alias not found in signature: ", alias,
139  " Outputs expected to be in the set {",
140  absl::StrJoin(GetMapKeys(signature.outputs()), ","),
141  "}."));
142  }
143  if (seen_outputs.find(alias) != seen_outputs.end()) {
144  return tensorflow::Status(
145  static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
146  "duplicate output tensor alias: " + alias);
147  }
148  seen_outputs.insert(alias);
149  output_tensor_names->emplace_back(iter->second.name());
150  output_tensor_aliases->emplace_back(alias);
151  }
152  // When no output is specified, fetch all output tensors specified in
153  // the signature.
154  if (output_tensor_names->empty()) {
155  for (auto& iter : signature.outputs()) {
156  output_tensor_names->emplace_back(iter.second.name());
157  output_tensor_aliases->emplace_back(iter.first);
158  }
159  }
160  return absl::OkStatus();
161 }
162 
163 Status PostProcessPredictionResult(
164  const std::vector<string>& output_tensor_aliases,
165  const std::vector<Tensor>& output_tensors,
166  const internal::PredictResponseTensorSerializationOption option,
167  PredictResponse* response) {
168  // Validate and return output.
169  if (output_tensors.size() != output_tensor_aliases.size()) {
170  return tensorflow::Status(
171  static_cast<tensorflow::errors::Code>(absl::StatusCode::kUnknown),
172  "Predict internal error");
173  }
174  switch (option) {
175  case internal::PredictResponseTensorSerializationOption::kAsProtoField: {
176  for (int i = 0; i < output_tensors.size(); i++) {
177  output_tensors[i].AsProtoField(
178  &((*response->mutable_outputs())[output_tensor_aliases[i]]));
179  }
180  } break;
181  case internal::PredictResponseTensorSerializationOption::kAsProtoContent: {
182  for (int i = 0; i < output_tensors.size(); i++) {
183  output_tensors[i].AsProtoTensorContent(
184  &((*response->mutable_outputs())[output_tensor_aliases[i]]));
185  }
186  } break;
187  }
188 
189  return absl::OkStatus();
190 }
191 
192 } // namespace internal
193 
194 Status RunPredict(const RunOptions& run_options,
195  const MetaGraphDef& meta_graph_def,
196  const absl::optional<int64_t>& servable_version,
197  Session* session, const PredictRequest& request,
198  PredictResponse* response,
199  const thread::ThreadPoolOptions& thread_pool_options) {
200  return internal::RunPredict(
201  run_options, meta_graph_def, servable_version,
202  internal::PredictResponseTensorSerializationOption::kAsProtoField,
203  session, request, response, thread_pool_options);
204 }
205 
206 } // namespace serving
207 } // namespace tensorflow