TensorFlow Serving C++ API Documentation
tfrt_servable.cc
/* Copyright 2023 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"

#include <stdint.h>

#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/tracing.h"  // NOLINT
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/threadpool_options.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/apis/get_model_metadata.pb.h"
#include "tensorflow_serving/apis/inference.pb.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/apis/regression.pb.h"
#include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
#include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_multi_inference.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_predict_util.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_regressor.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
#include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"

namespace tensorflow {
namespace serving {

TfrtSavedModelServable::TfrtSavedModelServable(
    absl::string_view name, int64_t version,
    const TfrtSavedModelConfig& config, const SavedModelConfig& model_config,
    std::unique_ptr<tfrt_stub::SavedModel> saved_model,
    ThreadPoolFactory* thread_pool_factory,
    std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
        recorder_creator)
    : Servable(name, version, model_config.critical()),
      saved_model_(std::move(saved_model)),
      config_(config),
      thread_pool_factory_(thread_pool_factory),
      recorder_creator_(std::move(recorder_creator)) {
  switch (config_.predict_response_tensor_serialization_option()) {
    case TfrtSavedModelConfig::AS_PROTO_FIELD: {
      predict_response_tensor_serialization_option_ =
          internal::PredictResponseTensorSerializationOption::kAsProtoField;
      break;
    }
    case TfrtSavedModelConfig::AS_PROTO_CONTENT: {
      predict_response_tensor_serialization_option_ =
          internal::PredictResponseTensorSerializationOption::kAsProtoContent;
      break;
    }
    default: {
      predict_response_tensor_serialization_option_ =
          internal::PredictResponseTensorSerializationOption::kAsProtoField;
      break;
    }
  }
}
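
The switch above maps the TfrtSavedModelConfig proto enum onto the internal serialization option, falling back to kAsProtoField for unrecognized values. A minimal configuration sketch, assuming only the proto field and enum names used above (and that TfrtSavedModelConfig lives in the source-adapter proto header included in this file):

#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"

tensorflow::serving::TfrtSavedModelConfig MakeConfig() {
  tensorflow::serving::TfrtSavedModelConfig config;
  // Serialize output tensors into the compact `tensor_content` bytes field
  // rather than the repeated typed value fields.
  config.set_predict_response_tensor_serialization_option(
      tensorflow::serving::TfrtSavedModelConfig::AS_PROTO_CONTENT);
  return config;
}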

tfrt_stub::SavedModel::RunOptions
TfrtSavedModelServable::GetTFRTSavedModelRunOptions(
    const Servable::RunOptions& run_options) const {
  tfrt_stub::SavedModel::RunOptions options;
  if (run_options.deadline != absl::InfiniteFuture()) {
    options.deadline = absl::ToChronoTime(run_options.deadline);
  }
  options.validate_input_specs = config_.validate_input_specs();
  options.validate_input_specs_dry_run = config_.validate_input_specs_dry_run();
  return options;
}
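
Only a finite deadline is propagated: absl::InfiniteFuture() leaves the TFRT-side deadline unset. A self-contained sketch of the same conversion, using only Abseil time utilities:

#include <chrono>

#include "absl/time/clock.h"
#include "absl/time/time.h"

int main() {
  absl::Time deadline = absl::Now() + absl::Seconds(5);
  std::chrono::system_clock::time_point chrono_deadline;
  if (deadline != absl::InfiniteFuture()) {
    // Same conversion GetTFRTSavedModelRunOptions applies before handing the
    // deadline to tfrt_stub::SavedModel::RunOptions.
    chrono_deadline = absl::ToChronoTime(deadline);
  }
  (void)chrono_deadline;
  return 0;
}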

absl::Status TfrtSavedModelServable::Classify(
    const RunOptions& run_options, const ClassificationRequest& request,
    ClassificationResponse* response) {
  TRACELITERAL("TfrtSavedModelServable::Classify");
  auto recorder = CreateRecorder();
  return RunClassify(GetTFRTSavedModelRunOptions(run_options), version(),
                     saved_model_.get(), request, response);
}

absl::Status TfrtSavedModelServable::Regress(const RunOptions& run_options,
                                             const RegressionRequest& request,
                                             RegressionResponse* response) {
  TRACELITERAL("TfrtSavedModelServable::Regress");
  auto recorder = CreateRecorder();
  return RunRegress(GetTFRTSavedModelRunOptions(run_options), version(),
                    saved_model_.get(), request, response);
}

absl::Status TfrtSavedModelServable::Predict(const RunOptions& run_options,
                                             const PredictRequest& request,
                                             PredictResponse* response) {
  TRACELITERAL("TfrtSavedModelServable::Predict");
  auto recorder = CreateRecorder();
  return internal::RunPredict(
      GetTFRTSavedModelRunOptions(run_options), version(),
      predict_response_tensor_serialization_option_, saved_model_.get(),
      request, response,
      thread_pool_factory_ == nullptr
          ? tsl::thread::ThreadPoolOptions()
          : thread_pool_factory_->GetThreadPools().get());
}
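
A hypothetical caller-side sketch of Predict (Classify and Regress are used analogously). The model name, input key, and tensor values are illustrative only, and `servable` stands for an already loaded TfrtSavedModelServable held through the Servable base class:

#include "absl/status/status.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"

absl::Status RunOnePredict(tensorflow::serving::Servable& servable) {
  tensorflow::serving::PredictRequest request;
  request.mutable_model_spec()->set_name("my_model");  // hypothetical name
  tensorflow::TensorProto& input = (*request.mutable_inputs())["x"];
  input.set_dtype(tensorflow::DT_FLOAT);
  input.mutable_tensor_shape()->add_dim()->set_size(1);
  input.add_float_val(1.0f);

  tensorflow::serving::PredictResponse response;
  // Default-constructed RunOptions implies an infinite deadline.
  return servable.Predict(tensorflow::serving::Servable::RunOptions(),
                          request, &response);
}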

// TODO(b/288096487): Add a unit test once we have the streaming model in OSS.
absl::StatusOr<std::unique_ptr<PredictStreamedContext>>
TfrtSavedModelServable::PredictStreamed(
    const RunOptions& run_options,
    absl::AnyInvocable<void(absl::StatusOr<PredictResponse>)>
        response_callback) {
  auto recorder = CreateRecorder();
  return std::make_unique<SingleRequestPredictStreamedContext>(
      [this, run_options, response_callback = std::move(response_callback)](
          const PredictRequest& request) mutable -> absl::Status {
        TRACELITERAL("TfrtSavedModelServable::PredictStreamed");

        auto tfrt_run_options = GetTFRTSavedModelRunOptions(run_options);

        std::string signature_name =
            request.model_spec().signature_name().empty()
                ? kDefaultServingSignatureDefKey
                : request.model_spec().signature_name();

        tensorflow::serving::ModelSpec model_spec = request.model_spec();
        model_spec.set_signature_name(signature_name);
        model_spec.mutable_version()->set_value(version());

        tfrt_run_options.streamed_output_callback =
            [&](absl::flat_hash_map<std::string, tensorflow::Tensor> outputs) {
              tensorflow::serving::PredictResponse response;
              *response.mutable_model_spec() = model_spec;

              for (const auto& [output_key, output_tensor] : outputs) {
                tensorflow::TensorProto& tensor_proto =
                    (*response.mutable_outputs())[output_key];

                // TODO(b/288096487): We are assuming
                // predict_response_tensor_serialization_option_ ==
                // kAsProtoField. The proper way is to serialize based on
                // the value of predict_response_tensor_serialization_option_.
                output_tensor.AsProtoField(&tensor_proto);
              }

              response_callback(std::move(response));
              // TODO(b/288096487): Add streamz support.
            };

        // The actual responses are passed through `response_callback`. The
        // graph is currently expected to have no output tensors.
        PredictResponse response;

        return internal::RunPredict(
            tfrt_run_options, version(),
            predict_response_tensor_serialization_option_, saved_model_.get(),
            request, &response,
            thread_pool_factory_ == nullptr
                ? tsl::thread::ThreadPoolOptions()
                : thread_pool_factory_->GetThreadPools().get());
      });
}
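
A hypothetical caller-side sketch, assuming only the PredictStreamed signature shown above. Each streamed output batch arrives as a complete PredictResponse through the callback; the PredictRequest itself is then fed through the returned context, whose interface is defined in servable.h and elided here:

#include <memory>
#include <utility>

#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"

void StartStreaming(tensorflow::serving::Servable& servable) {
  auto context_or = servable.PredictStreamed(
      tensorflow::serving::Servable::RunOptions(),
      [](absl::StatusOr<tensorflow::serving::PredictResponse> response) {
        // Invoked once per streamed output batch.
        if (response.ok()) {
          LOG(INFO) << "Got " << response->outputs_size() << " outputs";
        } else {
          LOG(ERROR) << response.status();
        }
      });
  if (!context_or.ok()) {
    LOG(ERROR) << context_or.status();
    return;
  }
  std::unique_ptr<tensorflow::serving::PredictStreamedContext> context =
      std::move(*context_or);
  // The actual PredictRequest is then passed through `context`; see
  // servable.h for the context interface.
}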

absl::Status TfrtSavedModelServable::MultiInference(
    const RunOptions& run_options, const MultiInferenceRequest& request,
    MultiInferenceResponse* response) {
  TRACELITERAL("TfrtSavedModelServable::MultiInference");
  auto recorder = CreateRecorder();
  return RunMultiInference(GetTFRTSavedModelRunOptions(run_options), version(),
                           saved_model_.get(), request, response);
}

absl::Status TfrtSavedModelServable::Suspend() {
  TRACELITERAL("TfrtSavedModelServable::Suspend");
  absl::MutexLock lock(&paging_mu_);
  if (!suspend_fn_) {
    return absl::UnimplementedError("Suspend is not implemented");
  }
  if (suspended_) {
    return absl::OkStatus();
  }
  absl::Status status = suspend_fn_(this);
  if (status.ok()) {
    suspended_ = true;
  }
  return status;
}

absl::Status TfrtSavedModelServable::Resume() {
  TRACELITERAL("TfrtSavedModelServable::Resume");
  absl::MutexLock lock(&paging_mu_);
  if (!resume_fn_) {
    return absl::UnimplementedError("Resume is not implemented");
  }
  if (!suspended_) {
    return absl::OkStatus();
  }
  absl::Status status = resume_fn_(this);
  // Clear the suspended flag only when the resume hook succeeds, mirroring
  // Suspend().
  if (status.ok()) {
    suspended_ = false;
  }
  return status;
}
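
Suspend and Resume implement an optional paging contract: both are serialized on paging_mu_, are idempotent via the suspended_ flag, and report Unimplemented when the corresponding hook was never installed. A hedged caller-side sketch that treats the unimplemented case as a no-op, assuming Suspend is exposed on the Servable base class as the override above suggests:

#include "absl/status/status.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"

// Hypothetical helper: page a servable out, tolerating servables that were
// built without suspend support.
absl::Status PageOut(tensorflow::serving::Servable& servable) {
  absl::Status status = servable.Suspend();
  if (absl::IsUnimplemented(status)) {
    return absl::OkStatus();  // No suspend hook installed; nothing to do.
  }
  return status;
}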

namespace {

absl::Status ValidateGetModelMetadataRequest(
    const GetModelMetadataRequest& request) {
  for (const auto& metadata_field : request.metadata_field()) {
    if (metadata_field != kSignatureDef) {
      return absl::InvalidArgumentError(
          absl::StrCat("Metadata field ", metadata_field, " is not supported"));
    }
  }
  return absl::OkStatus();
}

}  // namespace

RequestRecorder::~RequestRecorder() = default;

absl::Status TfrtSavedModelServable::GetModelMetadata(
    const GetModelMetadataRequest& request,
    GetModelMetadataResponse* response) {
  TRACELITERAL("TfrtSavedModelServable::GetModelMetadata");

  TF_RETURN_IF_ERROR(ValidateGetModelMetadataRequest(request));

  for (const auto& metadata_field : request.metadata_field()) {
    if (metadata_field == kSignatureDef) {
      SignatureDefMap signature_def_map;
      for (const auto& signature :
           saved_model_->GetMetaGraphDef().signature_def()) {
        (*signature_def_map.mutable_signature_def())[signature.first] =
            signature.second;
      }

      auto* response_model_spec = response->mutable_model_spec();
      response_model_spec->set_name(std::string(name()));
      response_model_spec->mutable_version()->set_value(version());
      (*response->mutable_metadata())[kSignatureDef].PackFrom(
          signature_def_map);
    } else {
      return absl::InvalidArgumentError(
          absl::StrCat("MetadataField ", metadata_field, " is not supported"));
    }
  }

  return absl::OkStatus();
}
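
Only the signature-def metadata field is accepted, both by the validation helper and by the loop above. A hypothetical request sketch; the model name is illustrative, and "signature_def" is the assumed string value of the kSignatureDef constant checked above:

#include "absl/status/status.h"
#include "tensorflow_serving/apis/get_model_metadata.pb.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"

absl::Status FetchSignatureDefs(tensorflow::serving::Servable& servable) {
  tensorflow::serving::GetModelMetadataRequest request;
  request.mutable_model_spec()->set_name("my_model");  // hypothetical name
  request.add_metadata_field("signature_def");  // assumed kSignatureDef value

  tensorflow::serving::GetModelMetadataResponse response;
  // On success, the response's metadata map holds a packed SignatureDefMap
  // under the same key.
  return servable.GetModelMetadata(request, &response);
}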

absl::StatusOr<std::unique_ptr<TfrtSavedModelServable>>
CreateTfrtSavedModelServable(
    const tensorflow::tfrt_stub::SavedModel::Options& options,
    absl::string_view name, int64_t version, absl::string_view saved_model_dir,
    absl::flat_hash_set<std::string> tags) {
  TF_ASSIGN_OR_RETURN(
      auto saved_model,
      tensorflow::tfrt_stub::SavedModelImpl::LoadSavedModel(
          options, saved_model_dir,
          std::unordered_set<std::string>(tags.begin(), tags.end())));

  TF_ASSIGN_OR_RETURN(
      auto saved_model_config,
      LoadSavedModelConfigOrDefault(std::string(saved_model_dir)));

  TfrtSavedModelConfig config;
  return std::make_unique<TfrtSavedModelServable>(
      name, version, config, saved_model_config, std::move(saved_model),
      /*thread_pool_factory=*/nullptr);
}
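
The factory loads the model via SavedModelImpl::LoadSavedModel, reads the per-model SavedModelConfig from the model directory, and wires up a servable with a default TfrtSavedModelConfig and no thread pool factory. A hypothetical loading sketch; the name, version, path, and tag are assumptions, and constructing SavedModel::Options (which requires a TFRT runtime) is elided:

#include <memory>

#include "absl/status/statusor.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"

absl::StatusOr<std::unique_ptr<tensorflow::serving::TfrtSavedModelServable>>
LoadServable(const tensorflow::tfrt_stub::SavedModel::Options& options) {
  return tensorflow::serving::CreateTfrtSavedModelServable(
      options, /*name=*/"my_model", /*version=*/1,
      /*saved_model_dir=*/"/models/my_model/1",  // hypothetical path
      /*tags=*/{"serve"});
}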

}  // namespace serving
}  // namespace tensorflow