TensorFlow Serving C++ API Documentation
tfrt_servable.h
1 /* Copyright 2023 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SERVABLE_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SERVABLE_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/base/thread_annotations.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/functional/any_invocable.h"
27 #include "absl/status/status.h"
28 #include "absl/status/statusor.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/synchronization/mutex.h"
31 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
32 #include "tensorflow_serving/apis/classification.pb.h"
33 #include "tensorflow_serving/apis/get_model_metadata.pb.h"
34 #include "tensorflow_serving/apis/inference.pb.h"
35 #include "tensorflow_serving/apis/predict.pb.h"
36 #include "tensorflow_serving/apis/regression.pb.h"
37 #include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
38 #include "tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
39 #include "tensorflow_serving/servables/tensorflow/servable.h"
40 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
41 #include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"
42 
43 namespace tensorflow {
44 namespace serving {
45 
// The RequestRecorder interface for implementations to inject custom metric and
// cost reporting.
//
// The interface declares no operations of its own; it only provides a virtual
// destructor so that concrete recorders can be owned and destroyed through a
// `std::unique_ptr<RequestRecorder>`.
class RequestRecorder {
 public:
  // Declared here and defined out of line (the definition is not part of this
  // header).
  virtual ~RequestRecorder();
};
52 
53 // Implements PredictionService`-like interface for a single SavedModel based on
54 // `tensorflow::tfrt_stub::SavedModel`. Executables are lazily compiled on its
55 // first use and cached. This class is thread-safe.
57  public:
58  TfrtSavedModelServable(absl::string_view name, int64_t version,
59  const TfrtSavedModelConfig& config,
60  const SavedModelConfig& model_config,
61  std::unique_ptr<tfrt_stub::SavedModel> saved_model,
62  ThreadPoolFactory* thread_pool_factory)
64  name, version, config, model_config, std::move(saved_model),
65  thread_pool_factory,
66  [](TfrtSavedModelServable&) { return nullptr; }) {}
67 
69  absl::string_view name, int64_t version,
70  const TfrtSavedModelConfig& config, const SavedModelConfig& model_config,
71  std::unique_ptr<tfrt_stub::SavedModel> saved_model,
72  ThreadPoolFactory* thread_pool_factory,
73  std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
74  recorder_creator);
75 
76  absl::Status Classify(const RunOptions& run_options,
77  const ClassificationRequest& request,
78  ClassificationResponse* response) override;
79 
80  absl::Status Regress(const RunOptions& run_options,
81  const RegressionRequest& request,
82  RegressionResponse* response) override;
83 
84  absl::Status Predict(const RunOptions& run_options,
85  const PredictRequest& request,
86  PredictResponse* response) override;
87 
88  absl::StatusOr<std::unique_ptr<PredictStreamedContext>> PredictStreamed(
89  const RunOptions& run_options,
90  absl::AnyInvocable<void(absl::StatusOr<PredictResponse>)>
91  response_callback) override;
92 
93  absl::Status MultiInference(const RunOptions& run_options,
94  const MultiInferenceRequest& request,
95  MultiInferenceResponse* response) override;
96 
97  absl::Status GetModelMetadata(const GetModelMetadataRequest& request,
98  GetModelMetadataResponse* response) override;
99 
100  bool SupportsPaging() const override { return true; }
101 
102  absl::Status Suspend() override;
103 
104  absl::Status Resume() override;
105 
106  tfrt_stub::SavedModel& saved_model() const { return *saved_model_; }
107 
108  void set_resume_fn(
109  absl::AnyInvocable<absl::Status(TfrtSavedModelServable*)> resume_fn) {
110  absl::MutexLock lock(&paging_mu_);
111  resume_fn_ = std::move(resume_fn);
112  }
113 
114  void set_suspend_fn(
115  absl::AnyInvocable<absl::Status(TfrtSavedModelServable*)> suspend_fn) {
116  absl::MutexLock lock(&paging_mu_);
117  suspend_fn_ = std::move(suspend_fn);
118  }
119 
120  private:
121  tfrt_stub::SavedModel::RunOptions GetTFRTSavedModelRunOptions(
122  const Servable::RunOptions& run_options) const;
123 
124  std::unique_ptr<RequestRecorder> CreateRecorder() {
125  return recorder_creator_(*this);
126  }
127 
128  std::unique_ptr<tfrt_stub::SavedModel> saved_model_;
129 
130  // `config_` is the adapter config, and it is the same for all
131  // TfrtSavedModelServables within a model server.
132  TfrtSavedModelConfig config_;
133 
134  internal::PredictResponseTensorSerializationOption
135  predict_response_tensor_serialization_option_ =
136  internal::PredictResponseTensorSerializationOption::kAsProtoField;
137 
138  // `thread_pool_factory_` is not owned by Servables. In a typical
139  // implementation, the factory will own the `thread_pool_factory_` and it will
140  // be shared across different Servables.
141  ThreadPoolFactory* thread_pool_factory_ = nullptr;
142 
143  std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
144  recorder_creator_ = [](TfrtSavedModelServable&) { return nullptr; };
145 
146  absl::AnyInvocable<absl::Status(TfrtSavedModelServable*)> suspend_fn_
147  ABSL_GUARDED_BY(paging_mu_);
148 
149  absl::AnyInvocable<absl::Status(TfrtSavedModelServable*)> resume_fn_
150  ABSL_GUARDED_BY(paging_mu_);
151 
152  bool suspended_ ABSL_GUARDED_BY(paging_mu_) = false;
153 
154  absl::Mutex paging_mu_;
155 };
156 
157 // Creates a TfrtSavedModelServable from `saved_model_dir`.
158 absl::StatusOr<std::unique_ptr<TfrtSavedModelServable>>
159 CreateTfrtSavedModelServable(
160  const tensorflow::tfrt_stub::SavedModel::Options& options,
161  absl::string_view name, int64_t version, absl::string_view saved_model_dir,
162  absl::flat_hash_set<std::string> tags);
163 
164 } // namespace serving
165 } // namespace tensorflow
166 
167 #endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SERVABLE_H_