TensorFlow Serving C++ API Documentation
tfrt_predict_util.cc
/* Copyright 2020 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_serving/servables/tensorflow/tfrt_predict_util.h"

#include <algorithm>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tsl/platform/error_logging.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/servables/tensorflow/predict_util.h"
#include "tensorflow_serving/servables/tensorflow/util.h"

namespace tensorflow {
namespace serving {
namespace {

// Validate the request and construct the input tensors. An input missing from
// the request falls back to the signature's default input, if any.
Status PreProcessPredictionWithoutOutputFilter(
    const tfrt::FunctionMetadata& function_metadata,
    const PredictRequest& request, std::vector<Tensor>* input_tensors) {
  input_tensors->reserve(function_metadata.GetInputNames().size());
  for (int i = 0; i < function_metadata.GetInputNames().size(); ++i) {
    const auto& input_name = function_metadata.GetInputNames()[i];
    const auto input = request.inputs().find(input_name);
    if (input == request.inputs().end()) {
      const auto& default_inputs = function_metadata.GetDefaultInputs();
      const auto& default_input = default_inputs.find(input_name);
      if (default_input == default_inputs.end()) {
        const std::set<string> request_inputs =
            saved_model::GetMapKeys(request.inputs());
        const std::set<string> required_inputs(
            function_metadata.GetInputNames().begin(),
            function_metadata.GetInputNames().end());
        const std::set<string> sent_extra =
            SetDifference(request_inputs, required_inputs);
        const std::set<string> missing =
            SetDifference(SetDifference(required_inputs, request_inputs),
                          saved_model::GetMapKeys(default_inputs));
        return errors::InvalidArgument(absl::StrCat(
            "Request inputs do not match required inputs for model `",
            request.model_spec().name(), "`. Send extra: {",
            absl::StrJoin(sent_extra, ","), "}. Missing but required: {",
            absl::StrJoin(missing, ","), "}."));
      }
      Tensor tensor;
      if (!tensor.FromProto(default_input->second)) {
        return errors::InvalidArgument(
            absl::StrCat("tensor parsing error: ", input_name));
      }
      input_tensors->emplace_back(std::move(tensor));
      continue;
    }
    Tensor tensor;
    if (!tensor.FromProto(input->second)) {
      return errors::InvalidArgument(
          absl::StrCat("tensor parsing error: ", input_name));
    }
    const auto expected_dtype = function_metadata.GetInputSpecs()[i].dtype;
    // TODO(b/188570937): Remove this type check and update related tests.
    if (expected_dtype != DT_INVALID  // Skip if the dtype is unspecified.
        && tensor.dtype() != expected_dtype) {
      return errors::InvalidArgument(
          absl::StrCat("Expected input ", input_name, " to be ",
                       DataTypeString(expected_dtype), " but get ",
                       DataTypeString(tensor.dtype()), "."));
    }
    input_tensors->emplace_back(std::move(tensor));
  }
  return absl::OkStatus();
}

// Validate results and populate a PredictResponse.
// Tensors are serialized as specified.
Status PostProcessPredictionResultWithoutOutputFilter(
    const std::vector<string>& output_tensor_names,
    const std::vector<Tensor>& output_tensors,
    const internal::PredictResponseTensorSerializationOption option,
    const PredictRequest& request, PredictResponse* response) {
  if (output_tensor_names.size() != output_tensors.size()) {
    return errors::Unknown("Predict internal error.");
  }

  std::unordered_set<string> output_filter(request.output_filter().begin(),
                                           request.output_filter().end());
  int output_size = 0;
  for (int i = 0; i < output_tensors.size(); ++i) {
    if (!output_filter.empty() &&
        output_filter.find(output_tensor_names[i]) == output_filter.end()) {
      continue;
    }
    switch (option) {
      case internal::PredictResponseTensorSerializationOption::kAsProtoField: {
        output_tensors[i].AsProtoField(
            &((*response->mutable_outputs())[output_tensor_names[i]]));
      } break;
      case internal::PredictResponseTensorSerializationOption::
          kAsProtoContent: {
        output_tensors[i].AsProtoTensorContent(
            &((*response->mutable_outputs())[output_tensor_names[i]]));
      } break;
    }
    output_size++;
  }

  if (!output_filter.empty() && output_filter.size() != output_size) {
    return errors::InvalidArgument(absl::StrCat(
        "output_filter contains non-existed output names. output_filter: ",
        absl::StrJoin(output_filter, ",")));
  }
  return absl::OkStatus();
}

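// Returns true if the request's output_filter is empty or names exactly the
// same set of outputs as the signature, i.e. the filter is a no-op.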
bool IsOutputFilterEmptyOrFullSet(
    const PredictRequest& request,
    const tfrt::FunctionMetadata& function_metadata) {
  if (request.output_filter().empty()) return true;
  if (request.output_filter().size() !=
      function_metadata.GetOutputNames().size())
    return false;
  std::vector<absl::string_view> output_filter_names(
      request.output_filter().begin(), request.output_filter().end());
  std::vector<absl::string_view> func_output_names(
      function_metadata.GetOutputNames().begin(),
      function_metadata.GetOutputNames().end());
  std::sort(output_filter_names.begin(), output_filter_names.end());
  std::sort(func_output_names.begin(), func_output_names.end());
  return output_filter_names == func_output_names;
}

}  // namespace

namespace internal {
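// Shared implementation of the Predict API on a TFRT SavedModel: resolves the
// requested signature, runs the request, and serializes the output tensors
// into the response according to `option`.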
Status RunPredict(
    const tfrt::SavedModel::RunOptions& run_options,
    const absl::optional<int64_t>& servable_version,
    const internal::PredictResponseTensorSerializationOption option,
    tfrt::SavedModel* saved_model, const PredictRequest& request,
    PredictResponse* response,
    const thread::ThreadPoolOptions& thread_pool_options) {
  // Validate signatures.
  const std::string function_name =
      request.model_spec().signature_name().empty()
          ? kDefaultServingSignatureDefKey
          : request.model_spec().signature_name();

  const auto function_metadata =
      saved_model->GetFunctionMetadata(function_name);
  if (!function_metadata.has_value()) {
    return errors::FailedPrecondition(
        strings::StrCat("Function \"", function_name, "\" not found."));
  }

  MakeModelSpec(request.model_spec().name(), function_name, servable_version,
                response->mutable_model_spec());

  auto run_opts = run_options;
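  // If the caller supplied an inter-op thread pool, wrap the provided thread
  // pools in a TfThreadPoolWorkQueue and use it as this run's work queue.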
  std::optional<tensorflow::tfrt_stub::TfThreadPoolWorkQueue> thread_pool;
  if (thread_pool_options.inter_op_threadpool != nullptr) {
    thread_pool.emplace(
        /*intra_op_threadpool=*/thread_pool_options.intra_op_threadpool,
        /*inter_op_threadpool=*/thread_pool_options.inter_op_threadpool);
    run_opts.work_queue = &(*thread_pool);
  }

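  // Fast path: when the output filter is empty or selects every output of the
  // signature, run the named function directly. Otherwise fall back to
  // RunByTensorNames below.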
  if (IsOutputFilterEmptyOrFullSet(request, function_metadata.value())) {
    // Pre-processing.
    std::vector<Tensor> input_tensors;
    TF_RETURN_IF_ERROR(PreProcessPredictionWithoutOutputFilter(
        function_metadata.value(), request, &input_tensors));

    // Execute the request.
    std::vector<Tensor> outputs;
    const uint64_t start_microseconds = EnvTime::NowMicros();
    if (const auto status =
            saved_model->Run(run_opts, function_name, input_tensors, &outputs);
        !status.ok()) {
      if (IsTfrtErrorLoggingEnabled()) {
        tsl::error_logging::Log("TFRT", "SavedModelRun", status.message())
            .IgnoreError();
      }
      return status;
    }
    const uint64_t end_microseconds = EnvTime::NowMicros();
    RecordRuntimeLatency(request.model_spec().name(), /*api=*/"Predict",
                         /*runtime=*/"TFRT",
                         end_microseconds - start_microseconds);

    // Post-processing.
    return PostProcessPredictionResultWithoutOutputFilter(
        function_metadata->GetOutputNames(), outputs, option, request,
        response);
  } else {
    // When output_filter is specified, use the RunByTensorNames API to trigger
    // lazy initialization of the optimized graph.
    // RunByTensorNames is discouraged in the long run; we should consider
    // deprecating output_filter and relying on separate signature defs
    // instead.
    const auto& metagraph_def = saved_model->GetMetaGraphDef();
    auto iter = metagraph_def.signature_def().find(function_name);
    if (iter == metagraph_def.signature_def().end()) {
      return errors::FailedPrecondition(strings::StrCat(
          "Serving signature key \"", function_name, "\" not found."));
    }
    const SignatureDef& signature = iter->second;

    std::vector<std::pair<string, Tensor>> input_tensors;
    std::vector<string> output_tensor_names;
    std::vector<string> output_tensor_aliases;
    TF_RETURN_IF_ERROR(PreProcessPrediction(signature, request, &input_tensors,
                                            &output_tensor_names,
                                            &output_tensor_aliases));

    const uint64_t start_microseconds = EnvTime::NowMicros();
    std::vector<Tensor> outputs;
    if (const auto status = saved_model->RunByTensorNames(
            run_opts, input_tensors, output_tensor_names,
            /*target_node_names=*/{}, &outputs);
        !status.ok()) {
      if (IsTfrtErrorLoggingEnabled()) {
        tsl::error_logging::Log("TFRT", "SavedModelRun", status.message())
            .IgnoreError();
      }
      return status;
    }
    const uint64_t end_microseconds = EnvTime::NowMicros();
    RecordRuntimeLatency(request.model_spec().name(), /*api=*/"Predict",
                         /*runtime=*/"TFRT",
                         end_microseconds - start_microseconds);

    return PostProcessPredictionResult(output_tensor_aliases, outputs, option,
                                       response);
  }
}
}  // namespace internal

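// Public entry point: convenience overload that serializes output tensors as
// proto fields (kAsProtoField).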
Status RunPredict(const tfrt::SavedModel::RunOptions& run_options,
                  const absl::optional<int64_t>& servable_version,
                  tfrt::SavedModel* saved_model, const PredictRequest& request,
                  PredictResponse* response,
                  const thread::ThreadPoolOptions& thread_pool_options) {
  return internal::RunPredict(
      run_options, servable_version,
      internal::PredictResponseTensorSerializationOption::kAsProtoField,
      saved_model, request, response, thread_pool_options);
}

}  // namespace serving
}  // namespace tensorflow