TensorFlow Serving C++ API Documentation
servable.h
/* Copyright 2023 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SERVABLE_H_
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SERVABLE_H_

#include <stdint.h>

#include <memory>
#include <string>

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/apis/get_model_metadata.pb.h"
#include "tensorflow_serving/apis/inference.pb.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/apis/regression.pb.h"
#include "tensorflow_serving/servables/tensorflow/run_options.h"

namespace tensorflow {
namespace serving {

inline constexpr absl::string_view kSignatureDef = "signature_def";

// Context of a `PredictStreamed` session. The caller of `PredictStreamed`
// calls `ProcessRequest` every time a request becomes available. The caller
// must call `Close()` at the end of the session before deleting the context
// object.
//
// The implementation can be thread-compatible. The caller is responsible for
// synchronizing all method invocations.
class PredictStreamedContext {
 public:
  virtual ~PredictStreamedContext() = default;

  // Consumes one incoming request. Blocking here may delay the consumption of
  // subsequent requests.
  virtual absl::Status ProcessRequest(const PredictRequest& request) = 0;

  // Closes the `PredictStreamed` session.
  virtual absl::Status Close() = 0;
};
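
// Example: a minimal sketch of the canonical `PredictStreamedContext`
// lifecycle, in which the caller feeds each request as it arrives and calls
// `Close()` before deleting the context. `DriveStreamedSession` is a
// hypothetical helper, not part of this header, and assumes <vector> is
// available.
inline absl::Status DriveStreamedSession(
    PredictStreamedContext& context,
    const std::vector<PredictRequest>& requests) {
  for (const PredictRequest& request : requests) {
    // Blocking inside ProcessRequest may delay consumption of later requests.
    if (absl::Status status = context.ProcessRequest(request); !status.ok()) {
      return status;
    }
  }
  // Close() ends the session; it must run before the context is destroyed.
  return context.Close();
}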

// A convenience wrapper for cases where the implementation allows exactly one
// request. `f` takes this single request and produces responses by calling the
// `response_callback` passed to `Servable::PredictStreamed`.
//
// This implementation is thread-compatible but not thread-safe.
class SingleRequestPredictStreamedContext : public PredictStreamedContext {
 public:
  explicit SingleRequestPredictStreamedContext(
      absl::AnyInvocable<absl::Status(const PredictRequest&)> f);

  absl::Status ProcessRequest(const PredictRequest& request) final;
  absl::Status Close() final;

 private:
  absl::AnyInvocable<absl::Status(const PredictRequest&)> f_;
  bool one_request_received_ = false;
};
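
// Example: a minimal sketch of wrapping a single-shot handler in a
// `SingleRequestPredictStreamedContext`. `MakeSingleRequestContext` is a
// hypothetical helper; the lambda body is a placeholder for real work that
// would produce responses via the `response_callback` passed to
// `Servable::PredictStreamed`.
inline std::unique_ptr<PredictStreamedContext> MakeSingleRequestContext() {
  return std::make_unique<SingleRequestPredictStreamedContext>(
      [](const PredictRequest&) {
        // Handle the single allowed request here; a second call to
        // `ProcessRequest` is expected to fail.
        return absl::OkStatus();
      });
}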

// Provides a `PredictionService`-like interface. All concrete implementations
// are expected to be thread-safe.
class Servable {
 public:
  Servable(absl::string_view name, int64_t version, bool is_critical = false)
      : name_(std::string(name)),
        version_(version),
        is_critical_(is_critical) {}

  virtual ~Servable() = default;

  // Returns the name associated with this servable.
  absl::string_view name() const { return name_; }

  // Returns the version associated with this servable.
  int64_t version() const { return version_; }

  bool IsCritical() const { return is_critical_; }

  using RunOptions = servables::RunOptions;

  virtual absl::Status Classify(const RunOptions& run_options,
                                const ClassificationRequest& request,
                                ClassificationResponse* response) = 0;

  virtual absl::Status Regress(const RunOptions& run_options,
                               const RegressionRequest& request,
                               RegressionResponse* response) = 0;

  virtual absl::Status Predict(const RunOptions& run_options,
                               const PredictRequest& request,
                               PredictResponse* response) = 0;

  // Bidirectional streamed version of `Predict`. Returns a "context" object
  // that allows the caller to pass requests incrementally. The servable is
  // kept alive until the context object is deleted.
  //
  // `response_callback` is called for each streamed output, zero or more
  // times, when the streamed output becomes available. If an error is returned
  // for any response, subsequent responses and requests will be ignored and
  // the error will be returned. The callback invocation must be serialized by
  // the implementation, so that `response_callback` does not have to be
  // thread-safe, but blocking inside the callback may cause the next callback
  // invocation to be delayed. The implementation must guarantee that the
  // callback is never called after the `PredictStreamed` method returns.
  virtual absl::StatusOr<std::unique_ptr<PredictStreamedContext>>
  PredictStreamed(const RunOptions& run_options,
                  absl::AnyInvocable<void(absl::StatusOr<PredictResponse>)>
                      response_callback) = 0;

  virtual absl::Status MultiInference(const RunOptions& run_options,
                                      const MultiInferenceRequest& request,
                                      MultiInferenceResponse* response) = 0;

  virtual absl::Status GetModelMetadata(const GetModelMetadataRequest& request,
                                        GetModelMetadataResponse* response) = 0;

  // Returns true iff this servable supports paging.
  //
  // Paging is the process of moving model data (i.e., variables and
  // executables) between devices' HBM and host RAM. Servables that support
  // paging can time-share the available HBM and be paged in and out of the HBM
  // according to a paging policy.
  //
  // Note that even if a Servable supports paging, it is up to a Server
  // implementation to make active (or any!) use of the paging functionality.
  virtual bool SupportsPaging() const;

  // Pages out all variables and executables owned by this servable from
  // devices' HBM to host RAM.
  //
  // After this method returns, all requests return an error until `Resume()`
  // is called to bring the state back into device memory.
  //
  // If the suspension fails, the model is in an unspecified state and must be
  // unloaded and loaded again for it to be useful.
  //
  // This method may only be invoked if `SupportsPaging()` returns true.
  virtual absl::Status Suspend();

  // Inverse of `Suspend()`. Synchronously pages all variables and executables
  // owned by this servable back into devices' HBM.
  //
  // Returns an error if the servable is not in a suspended state or if the
  // resumption failed. If the resumption fails, the model is in an unspecified
  // state and must be unloaded and loaded again for it to be useful.
  //
  // This method may only be invoked if `SupportsPaging()` returns true.
  virtual absl::Status Resume();

 private:
  // Metadata of this servable. Currently matches the fields in `ServableId`.
  const std::string name_;
  const int64_t version_;
  const bool is_critical_;
};
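
// Example: a minimal sketch of a streamed prediction round-trip against any
// concrete `Servable`. `StreamOnePredict` is a hypothetical helper, and the
// response callback here merely discards each streamed response.
inline absl::Status StreamOnePredict(Servable& servable,
                                     const PredictRequest& request) {
  Servable::RunOptions run_options;
  absl::StatusOr<std::unique_ptr<PredictStreamedContext>> context =
      servable.PredictStreamed(
          run_options,
          // Invocations are serialized by the implementation, so the callback
          // need not be thread-safe.
          [](absl::StatusOr<PredictResponse> response) {
            // A non-OK `response` signals a failed streamed output.
          });
  if (!context.ok()) return context.status();
  if (absl::Status status = (*context)->ProcessRequest(request);
      !status.ok()) {
    return status;
  }
  // The servable is kept alive until `context` is deleted; Close() must be
  // called first.
  return (*context)->Close();
}

// Example: a minimal sketch of paging a servable out of device HBM and back
// in, guarded by `SupportsPaging()` as the contract requires. `PageOutAndIn`
// is a hypothetical helper.
inline absl::Status PageOutAndIn(Servable& servable) {
  if (!servable.SupportsPaging()) {
    return absl::UnimplementedError("servable does not support paging");
  }
  // While suspended, all requests fail until Resume() succeeds.
  if (absl::Status status = servable.Suspend(); !status.ok()) return status;
  return servable.Resume();
}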

// An "empty" servable with no model associated with it. All methods return an
// error.
//
// Empty servables can be used in places where a servable is expected but we
// don't need to load any models. For example, Model Server currently expects
// each task to have at least one servable loaded, but Pathways Serving
// requires only the controller task to initiate loading servables. So we use
// empty servables in non-zero tasks to make sure those tasks don't load
// anything.
class EmptyServable : public Servable {
 public:
  EmptyServable();

  absl::Status Classify(const RunOptions& run_options,
                        const ClassificationRequest& request,
                        ClassificationResponse* response) override {
    return error_;
  }

  absl::Status Regress(const RunOptions& run_options,
                       const RegressionRequest& request,
                       RegressionResponse* response) override {
    return error_;
  }

  absl::Status Predict(const RunOptions& run_options,
                       const PredictRequest& request,
                       PredictResponse* response) override {
    return error_;
  }

  absl::StatusOr<std::unique_ptr<PredictStreamedContext>> PredictStreamed(
      const RunOptions& run_options,
      absl::AnyInvocable<void(absl::StatusOr<PredictResponse>)>
          response_callback) override {
    return error_;
  }

  absl::Status MultiInference(const RunOptions& run_options,
                              const MultiInferenceRequest& request,
                              MultiInferenceResponse* response) override {
    return error_;
  }

  absl::Status GetModelMetadata(const GetModelMetadataRequest& request,
                                GetModelMetadataResponse* response) override {
    return error_;
  }

 private:
  absl::Status error_;
};
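
// Example: a minimal sketch showing that every call on an `EmptyServable`
// fails, which is what keeps placeholder tasks from serving traffic.
// `CallEmptyServable` is a hypothetical helper.
inline absl::Status CallEmptyServable() {
  EmptyServable servable;
  PredictRequest request;
  PredictResponse response;
  // Expected to be non-OK: there is no model behind an EmptyServable.
  return servable.Predict(Servable::RunOptions(), request, &response);
}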

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SERVABLE_H_