TensorFlow Serving C++ API Documentation
source_adapter.h
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_CORE_SOURCE_ADAPTER_H_
17 #define TENSORFLOW_SERVING_CORE_SOURCE_ADAPTER_H_
18 
19 #include <algorithm>
20 #include <vector>
21 
22 #include "tensorflow/core/lib/core/notification.h"
23 #include "tensorflow/core/lib/core/stringpiece.h"
24 #include "tensorflow/core/lib/io/path.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/macros.h"
27 #include "tensorflow/core/platform/types.h"
28 #include "tensorflow_serving/core/loader.h"
29 #include "tensorflow_serving/core/servable_data.h"
30 #include "tensorflow_serving/core/source.h"
31 #include "tensorflow_serving/core/storage_path.h"
32 #include "tensorflow_serving/core/target.h"
33 #include "tensorflow_serving/util/class_registration.h"
34 
35 namespace tensorflow {
36 namespace serving {
37 
58 template <typename InputType, typename OutputType>
59 class SourceAdapter : public TargetBase<InputType>, public Source<OutputType> {
60  public:
61  ~SourceAdapter() override = 0;
62 
65  void SetAspiredVersions(const StringPiece servable_name,
66  std::vector<ServableData<InputType>> versions) final;
67 
68  void SetAspiredVersionsCallback(
69  typename Source<OutputType>::AspiredVersionsCallback callback) final;
70 
73  virtual std::vector<ServableData<OutputType>> Adapt(
74  const StringPiece servable_name,
75  std::vector<ServableData<InputType>> versions) = 0;
76 
79 
80  protected:
81  // This is an abstract class.
82  SourceAdapter() = default;
83 
84  private:
85  // The callback for emitting OutputType-based aspired-version lists.
86  typename Source<OutputType>::AspiredVersionsCallback outgoing_callback_;
87 
88  // Has 'outgoing_callback_' been set yet, so that the SourceAdapter is ready
89  // to propagate aspired versions?
90  Notification outgoing_callback_set_;
91 };
92 
93 // START_SKIP_DOXYGEN
94 
95 // Define a SourceAdapter registry for the common case of adapting from a
96 // storage path to a loader.
99 DEFINE_CLASS_REGISTRY(StoragePathSourceAdapterRegistry,
// Registers 'Creator' in StoragePathSourceAdapterRegistry under the config
// proto type 'Config'. The expansion already ends in a semicolon.
#define REGISTER_STORAGE_PATH_SOURCE_ADAPTER(Creator, Config)                 \
  REGISTER_CLASS(StoragePathSourceAdapterRegistry, StoragePathSourceAdapter, \
                 Creator, Config);
104 
105 // A source adapter that converts InputType instances to OutputType instances
106 // one at a time (i.e. there is no interaction among members of a given aspired-
107 // version list). Most source adapters can subclass UnarySourceAdapter, and do
108 // not need the full generality of SourceAdapter.
109 //
// Requires OutputType to be default-constructible and updatable in-place.
111 //
112 // Implementing subclasses supply an implementation of the Convert() virtual
113 // method, which converts a servable from InputType to OutputType.
114 //
115 // IMPORTANT: Every leaf derived class must call Detach() at the top of its
116 // destructor. (See documentation on TargetBase::Detach() in target.h.) Doing so
117 // ensures that no Convert() calls are in flight during destruction of member
118 // variables.
119 template <typename InputType, typename OutputType>
120 class UnarySourceAdapter : public SourceAdapter<InputType, OutputType> {
121  public:
122  ~UnarySourceAdapter() override = 0;
123 
124  protected:
125  // This is an abstract class.
126  UnarySourceAdapter() = default;
127 
128  private:
129  // This method is implemented in terms of Convert(), which the implementing
130  // subclass must supply.
131  std::vector<ServableData<OutputType>> Adapt(
132  const StringPiece servable_name,
133  std::vector<ServableData<InputType>> versions) final;
134 
135  // Converts a single InputType instance into a corresponding OutputType
136  // instance.
137  virtual Status Convert(const InputType& data, OutputType* converted_data) = 0;
138 };
139 
140 // A source adapter that converts every incoming ServableData<InputType> item
141 // into an error-containing ServableData<OutputType>. If the incoming data item
142 // was already an error, the existing error is passed through; otherwise a new
143 // error Status given via this class's constructor is added.
144 //
145 // This class is useful in conjunction with a router, to handle servable data
146 // items that do not conform to any explicitly-programmed route. Specifically,
// consider a fruit router configured to route apples to output port 0, oranges to
148 // output port 1, and anything else to a final port 2. If we only have proper
149 // SourceAdapters to handle apples and oranges, we might connect an
150 // ErrorInjectingSourceAdapter to port 2, to catch any unexpected fruits.
151 template <typename InputType, typename OutputType>
153  : public SourceAdapter<InputType, OutputType> {
154  public:
155  explicit ErrorInjectingSourceAdapter(const Status& error);
156  ~ErrorInjectingSourceAdapter() override;
157 
158  private:
159  std::vector<ServableData<OutputType>> Adapt(
160  const StringPiece servable_name,
161  std::vector<ServableData<InputType>> versions) override;
162 
163  // The error status to inject.
164  const Status error_;
165 
166  TF_DISALLOW_COPY_AND_ASSIGN(ErrorInjectingSourceAdapter);
167 };
168 
170 // Implementation details follow. API users need not read.
171 
172 template <typename InputType, typename OutputType>
174 
175 template <typename InputType, typename OutputType>
177  const StringPiece servable_name,
178  std::vector<ServableData<InputType>> versions) {
179  outgoing_callback_set_.WaitForNotification();
180  outgoing_callback_(servable_name, Adapt(servable_name, std::move(versions)));
181 }
182 
183 template <typename InputType, typename OutputType>
186  outgoing_callback_ = callback;
187  outgoing_callback_set_.Notify();
188 }
189 
190 template <typename InputType, typename OutputType>
192  ServableData<InputType> input) {
193  const StringPiece servable_name(input.id().name);
194  std::vector<ServableData<InputType>> input_versions = {input};
195  std::vector<ServableData<OutputType>> output_versions =
196  Adapt(servable_name, input_versions);
197  DCHECK_EQ(1, output_versions.size());
198  return std::move(output_versions[0]);
199 }
200 
201 template <typename InputType, typename OutputType>
203 
204 template <typename InputType, typename OutputType>
205 std::vector<ServableData<OutputType>>
206 UnarySourceAdapter<InputType, OutputType>::Adapt(
207  const StringPiece servable_name,
208  std::vector<ServableData<InputType>> versions) {
209  std::vector<ServableData<OutputType>> adapted_versions;
210  for (const ServableData<InputType>& version : versions) {
211  if (version.status().ok()) {
212  OutputType adapted_data;
213  Status adapt_status = Convert(version.DataOrDie(), &adapted_data);
214  if (adapt_status.ok()) {
215  adapted_versions.emplace_back(
216  ServableData<OutputType>{version.id(), std::move(adapted_data)});
217  } else {
218  adapted_versions.emplace_back(
219  ServableData<OutputType>{version.id(), adapt_status});
220  }
221  } else {
222  adapted_versions.emplace_back(
223  ServableData<OutputType>{version.id(), version.status()});
224  }
225  }
226  return adapted_versions;
227 }
228 
229 template <typename InputType, typename OutputType>
230 ErrorInjectingSourceAdapter<InputType, OutputType>::ErrorInjectingSourceAdapter(
231  const Status& error)
232  : error_(error) {
233  DCHECK(!error.ok());
234 }
235 
236 template <typename InputType, typename OutputType>
237 ErrorInjectingSourceAdapter<InputType,
238  OutputType>::~ErrorInjectingSourceAdapter() {
239  TargetBase<InputType>::Detach();
240 }
241 
242 template <typename InputType, typename OutputType>
243 std::vector<ServableData<OutputType>>
244 ErrorInjectingSourceAdapter<InputType, OutputType>::Adapt(
245  const StringPiece servable_name,
246  std::vector<ServableData<InputType>> versions) {
247  std::vector<ServableData<OutputType>> adapted_versions;
248  for (const ServableData<InputType>& version : versions) {
249  if (version.status().ok()) {
250  LOG(INFO) << "Injecting error for servable " << version.id() << ": "
251  << error_.message();
252  adapted_versions.emplace_back(
253  ServableData<OutputType>{version.id(), error_});
254  } else {
255  adapted_versions.emplace_back(
256  ServableData<OutputType>{version.id(), version.status()});
257  }
258  }
259  return adapted_versions;
260 }
261 
262 // END_SKIP_DOXYGEN
263 
264 } // namespace serving
265 } // namespace tensorflow
266 
267 #endif // TENSORFLOW_SERVING_CORE_SOURCE_ADAPTER_H_
virtual std::vector< ServableData< OutputType > > Adapt(const StringPiece servable_name, std::vector< ServableData< InputType >> versions)=0
void SetAspiredVersions(const StringPiece servable_name, std::vector< ServableData< InputType >> versions) final
ServableData< OutputType > AdaptOneVersion(ServableData< InputType > input)
Adapts a single servable data item. (Implemented on top of Adapt().)