TensorFlow Serving C++ API Documentation
tfrt_regressor.cc
/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_serving/servables/tensorflow/tfrt_regressor.h"

#include <stddef.h>

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tsl/platform/error_logging.h"
#include "tensorflow_serving/apis/input.pb.h"
#include "tensorflow_serving/apis/model.pb.h"
#include "tensorflow_serving/apis/regression.pb.h"
#include "tensorflow_serving/apis/regressor.h"
#include "tensorflow_serving/servables/tensorflow/util.h"

namespace tensorflow {
namespace serving {

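// Validates that the function metadata describes a regression-compatible
// signature: exactly one input tensor, named kRegressInputs, and exactly one
// output tensor, named kRegressOutputs.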
Status PreProcessRegression(const tfrt::FunctionMetadata& function_metadata) {
  if (function_metadata.GetInputNames().size() != 1) {
    return errors::InvalidArgument(
        strings::StrCat("Expected one input Tensor."));
  }
  if (function_metadata.GetOutputNames().size() != 1) {
    return errors::InvalidArgument(
        strings::StrCat("Expected one output Tensor."));
  }

  if (function_metadata.GetInputNames()[0] != kRegressInputs) {
    return errors::FailedPrecondition(
        "No regression inputs found in function's metadata, only contains: ",
        function_metadata.GetInputNames()[0]);
  }

  if (function_metadata.GetOutputNames()[0] != kRegressOutputs) {
    return errors::FailedPrecondition(
        "No regression outputs found in function's metadata, only contains: ",
        function_metadata.GetOutputNames()[0]);
  }

  return absl::OkStatus();
}

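// Validates the model's output tensor (DT_FLOAT, shape [batch_size] or
// [batch_size, 1]) and copies one regression value per example into `result`.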
Status PostProcessRegressionResult(
    int num_examples, const std::vector<string>& output_tensor_names,
    const std::vector<Tensor>& output_tensors, RegressionResult* result) {
  if (output_tensors.size() != output_tensor_names.size()) {
    return errors::InvalidArgument(
        "Expected output_tensors and output_tensor_names to have the same "
        "size.");
  }

  const Tensor* output_tensor = &output_tensors[0];

  if (!(output_tensor->dims() == 1 ||
        (output_tensor->dims() == 2 && output_tensor->dim_size(1) == 1))) {
    return errors::InvalidArgument(
        "Expected output Tensor shape to be either [batch_size] or ",
        "[batch_size, 1] but got ", output_tensor->shape().DebugString());
  }
  if (num_examples != output_tensor->dim_size(0)) {
    return errors::InvalidArgument(strings::StrCat(
        "Input batch size did not match output batch size: ", num_examples,
        " vs. ", output_tensor->dim_size(0)));
  }
  if (output_tensor->dtype() != DT_FLOAT) {
    return errors::InvalidArgument("Expected output Tensor of DT_FLOAT. Got: ",
                                   DataType_Name(output_tensor->dtype()));
  }

  if (output_tensor->NumElements() != num_examples) {
    return errors::InvalidArgument("Expected output batch size to be ",
                                   num_examples,
                                   ". Got: ", output_tensor->NumElements());
  }

  const auto& output_tensor_flat = output_tensor->flat<float>();
  for (int i = 0; i < num_examples; ++i) {
    result->add_regressions()->set_value(output_tensor_flat(i));
  }
  return absl::OkStatus();
}

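// Runs the regression signature named in the request (falling back to the
// default serving signature) on `saved_model`: converts the request input to
// a serialized-Example tensor, executes the function, records runtime
// latency, and fills `response` with one regression value per input example.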
Status RunRegress(const tfrt::SavedModel::RunOptions& run_options,
                  const absl::optional<int64_t>& servable_version,
                  tfrt::SavedModel* saved_model,
                  const RegressionRequest& request,
                  RegressionResponse* response) {
  const string function_name = request.model_spec().signature_name().empty()
                                   ? kDefaultServingSignatureDefKey
                                   : request.model_spec().signature_name();

  const auto function_metadata =
      saved_model->GetFunctionMetadata(function_name);
  if (!function_metadata.has_value()) {
    return errors::FailedPrecondition(
        strings::StrCat("Function \"", function_name, "\" not found."));
  }

  MakeModelSpec(request.model_spec().name(), function_name, servable_version,
                response->mutable_model_spec());

  // Pre-processing.
  TF_RETURN_IF_ERROR(PreProcessRegression(function_metadata.value()));
  Tensor input_tensor;
  TF_RETURN_IF_ERROR(
      InputToSerializedExampleTensor(request.input(), &input_tensor));
  std::vector<Tensor> input_tensors;
  int num_examples = input_tensor.dim_size(0);
  input_tensors.emplace_back(std::move(input_tensor));

  // Executes requests.
  std::vector<Tensor> output_tensors;
  const uint64_t start_microseconds = EnvTime::NowMicros();
  if (const auto status = saved_model->Run(run_options, function_name,
                                           input_tensors, &output_tensors);
      !status.ok()) {
    if (IsTfrtErrorLoggingEnabled()) {
      tsl::error_logging::Log("TFRT", "SavedModelRun", status.message())
          .IgnoreError();
    }
    return status;
  }
  const uint64_t end_microseconds = EnvTime::NowMicros();
  RecordRuntimeLatency(request.model_spec().name(),
                       /*api=*/"Regress", /*runtime=*/"TFRT",
                       end_microseconds - start_microseconds);

  // Post-processing.
  return PostProcessRegressionResult(
      num_examples, function_metadata->GetOutputNames(), output_tensors,
      response->mutable_result());
}

}  // namespace serving
}  // namespace tensorflow
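
Example usage (sketch)

The sketch below shows how a caller might drive RunRegress directly, outside of the gRPC/HTTP frontends. It is not part of tfrt_regressor.cc: the function RegressOneExample, the model name "example_model", and the feature key "x" are placeholder assumptions, and in a real server the tfrt::SavedModel would come from the servable machinery rather than being passed in ad hoc.

#include "absl/types/optional.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow_serving/apis/regression.pb.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_regressor.h"

namespace tensorflow {
namespace serving {

// Hypothetical helper: regress a single tf.Example against an already-loaded
// TFRT SavedModel.
Status RegressOneExample(tfrt::SavedModel* saved_model,
                         RegressionResponse* response) {
  RegressionRequest request;
  request.mutable_model_spec()->set_name("example_model");  // placeholder name
  // Leaving signature_name empty selects kDefaultServingSignatureDefKey.

  // Build one tf.Example with a single float feature "x" (placeholder key).
  Example* example =
      request.mutable_input()->mutable_example_list()->add_examples();
  (*example->mutable_features()->mutable_feature())["x"]
      .mutable_float_list()
      ->add_value(1.0f);

  // With no servable version to report, pass absl::nullopt; the version is
  // only echoed back in the response's ModelSpec.
  tfrt::SavedModel::RunOptions run_options;
  TF_RETURN_IF_ERROR(RunRegress(run_options,
                                /*servable_version=*/absl::nullopt,
                                saved_model, request, response));
  // On success, response->result().regressions(0).value() holds the predicted
  // value for the single example above.
  return absl::OkStatus();
}

}  // namespace serving
}  // namespace tensorflow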