TensorFlow Serving C++ API Documentation
tfrt_multi_inference.cc
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/tfrt_multi_inference.h"
17 
18 #include "tensorflow/cc/saved_model/loader.h"
19 #include "tensorflow/cc/saved_model/signature_constants.h"
20 #include "tensorflow/core/tfrt/utils/tensor_util.h"
21 #include "tsl/platform/error_logging.h"
22 #include "tensorflow_serving/apis/input.pb.h"
23 #include "tensorflow_serving/apis/model.pb.h"
24 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
25 #include "tensorflow_serving/servables/tensorflow/tfrt_regressor.h"
26 #include "tensorflow_serving/servables/tensorflow/util.h"
27 
28 namespace tensorflow {
29 namespace serving {
30 
31 Status RunMultiInference(const tfrt::SavedModel::RunOptions& run_options,
32  const absl::optional<int64_t>& servable_version,
33  tfrt::SavedModel* saved_model,
34  const MultiInferenceRequest& request,
35  MultiInferenceResponse* response) {
36  Tensor input_tensor;
37  TF_RETURN_IF_ERROR(
38  InputToSerializedExampleTensor(request.input(), &input_tensor));
39  std::vector<std::vector<Tensor>> input_tensors;
40  int num_examples = input_tensor.dim_size(0);
41  input_tensors.resize(request.tasks_size());
42  for (int i = 0; i < request.tasks_size(); ++i) {
43  input_tensors[i].emplace_back(input_tensor);
44  }
45 
46  // Pre-processing.
47  std::string model_name = "";
48  std::set<std::string> function_names_set;
49  std::vector<std::string> function_names;
50  function_names.reserve(request.tasks_size());
51  for (const auto& task : request.tasks()) {
52  if (task.model_spec().name().empty()) {
53  return errors::InvalidArgument(
54  "Found ModelSpec with an empty model name.");
55  }
56  if (model_name.empty()) {
57  model_name = task.model_spec().name();
58  } else if (model_name != task.model_spec().name()) {
59  return errors::InvalidArgument(
60  "All ModelSpecs in a MultiInferenceRequest must access the same "
61  "model name.");
62  }
63 
64  const std::string function_name = task.model_spec().signature_name().empty()
65  ? kDefaultServingSignatureDefKey
66  : task.model_spec().signature_name();
67 
68  // TODO(b/183949363): Remove the constrain here. We could allow duplicated
69  // function names and simply return result for each of them.
70  if (function_names_set.find(function_name) != function_names_set.end()) {
71  return errors::InvalidArgument(strings::StrCat(
72  "Duplicate evaluation of signature: ", function_name));
73  }
74  function_names_set.insert(function_name);
75  function_names.push_back(function_name);
76 
77  const auto function_metadata =
78  saved_model->GetFunctionMetadata(function_name);
79  if (!function_metadata.has_value()) {
80  return errors::InvalidArgument(
81  strings::StrCat("Function \"", function_name, "\" not found."));
82  }
83 
84  if (task.method_name() == kClassifyMethodName) {
85  TF_RETURN_IF_ERROR(PreProcessClassification(function_metadata.value()));
86  } else if (task.method_name() == kRegressMethodName) {
87  TF_RETURN_IF_ERROR(PreProcessRegression(function_metadata.value()));
88  } else {
89  return errors::Unimplemented("Unsupported signature method_name: ",
90  task.method_name());
91  }
92  }
93 
94  // Executes requests.
95  std::vector<std::vector<Tensor>> output_tensors;
96  if (const auto status = saved_model->RunMultipleSignatures(
97  run_options, function_names, input_tensors, &output_tensors);
98  !status.ok()) {
99  if (IsTfrtErrorLoggingEnabled()) {
100  tsl::error_logging::Log("TFRT", "SavedModelRun", status.message())
101  .IgnoreError();
102  }
103  return status;
104  }
105 
106  // Post-processing.
107  for (int i = 0; i < request.tasks_size(); ++i) {
108  // We have already checked the existence of the function metadata before
109  // execution.
110  const auto function_metadata =
111  saved_model->GetFunctionMetadata(function_names[i]);
112  DCHECK(function_metadata.has_value());
113  if (request.tasks(i).method_name() == kClassifyMethodName) {
114  TF_RETURN_IF_ERROR(PostProcessClassificationResult(
115  num_examples, function_metadata->GetOutputNames(), output_tensors[i],
116  response->add_results()->mutable_classification_result()));
117  } else if (request.tasks(i).method_name() == kRegressMethodName) {
118  TF_RETURN_IF_ERROR(PostProcessRegressionResult(
119  num_examples, function_metadata->GetOutputNames(), output_tensors[i],
120  response->add_results()->mutable_regression_result()));
121  } else {
122  return errors::InvalidArgument("Unrecognized signature method_name: ",
123  request.tasks(i).method_name());
124  }
125  MakeModelSpec(request.tasks(i).model_spec().name(),
126  request.tasks(i).model_spec().signature_name(),
127  servable_version,
128  response->mutable_results(response->results_size() - 1)
129  ->mutable_model_spec());
130  }
131  RecordRequestExampleCount(model_name, num_examples);
132  return absl::OkStatus();
133 }
134 
135 } // namespace serving
136 } // namespace tensorflow