16 #ifndef TENSORFLOW_SERVING_CORE_SOURCE_ADAPTER_H_
17 #define TENSORFLOW_SERVING_CORE_SOURCE_ADAPTER_H_
22 #include "tensorflow/core/lib/core/notification.h"
23 #include "tensorflow/core/lib/core/stringpiece.h"
24 #include "tensorflow/core/lib/io/path.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/macros.h"
27 #include "tensorflow/core/platform/types.h"
28 #include "tensorflow_serving/core/loader.h"
29 #include "tensorflow_serving/core/servable_data.h"
30 #include "tensorflow_serving/core/source.h"
31 #include "tensorflow_serving/core/storage_path.h"
32 #include "tensorflow_serving/core/target.h"
33 #include "tensorflow_serving/util/class_registration.h"
35 namespace tensorflow {
58 template <
typename InputType,
typename OutputType>
68 void SetAspiredVersionsCallback(
73 virtual std::vector<ServableData<OutputType>>
Adapt(
74 const StringPiece servable_name,
// Signals that outgoing_callback_ is ready for use: it is notified right after
// outgoing_callback_ is assigned, and it is waited on before
// outgoing_callback_ is invoked.
90 Notification outgoing_callback_set_;
99 DEFINE_CLASS_REGISTRY(StoragePathSourceAdapterRegistry,
// Convenience wrapper around REGISTER_CLASS: passes
// StoragePathSourceAdapterRegistry and StoragePathSourceAdapter as the first
// two arguments and forwards ClassCreator and ConfigProto unchanged.
// NOTE(review): the expansion already ends in ';', so a use site that appends
// its own ';' produces an extra empty declaration -- confirm this is intended.
101 #define REGISTER_STORAGE_PATH_SOURCE_ADAPTER(ClassCreator, ConfigProto) \
102 REGISTER_CLASS(StoragePathSourceAdapterRegistry, StoragePathSourceAdapter, \
103 ClassCreator, ConfigProto);
119 template <
typename InputType,
typename OutputType>
131 std::vector<ServableData<OutputType>> Adapt(
132 const StringPiece servable_name,
// Converts one input item `data` into `*converted_data` and reports the
// outcome as a Status. Pure virtual, so a subclass supplies the conversion.
// UnarySourceAdapter::Adapt() calls this for each input version whose status
// is OK and records either the converted data or the returned Status for that
// version.
137 virtual Status Convert(
const InputType& data, OutputType* converted_data) = 0;
151 template <
typename InputType,
typename OutputType>
159 std::vector<ServableData<OutputType>> Adapt(
160 const StringPiece servable_name,
172 template <
typename InputType,
typename OutputType>
// NOTE(review): presumably the out-of-line definition of SetAspiredVersions();
// the line naming the function is not part of this extract -- confirm against
// the full header.
// Waits on outgoing_callback_set_ until it has been notified, then runs
// Adapt() on the incoming versions and passes the result to
// outgoing_callback_ under the same servable name.
175 template <
typename InputType,
typename OutputType>
177 const StringPiece servable_name,
179 outgoing_callback_set_.WaitForNotification();
// `versions` is moved into Adapt(), so it must not be read after this call.
180 outgoing_callback_(servable_name, Adapt(servable_name, std::move(versions)));
// NOTE(review): presumably the out-of-line definition of
// SetAspiredVersionsCallback(); the signature line is not part of this
// extract -- confirm against the full header.
// Stores `callback` in outgoing_callback_ and then notifies
// outgoing_callback_set_. The assignment comes first, so the callback is in
// place before any caller blocked in WaitForNotification() is released.
183 template <
typename InputType,
typename OutputType>
186 outgoing_callback_ = callback;
187 outgoing_callback_set_.Notify();
// NOTE(review): presumably the body of AdaptOneVersion(); the signature line
// is not part of this extract -- confirm against the full header.
// Adapts a single item: wraps `input` in a one-element vector, runs Adapt() on
// it under the name taken from input.id().name, and returns the first element
// of the result.
// The "exactly one output" expectation is checked only by DCHECK_EQ;
// output_versions[0] is accessed unconditionally afterwards.
// NOTE(review): `= {input}` copies `input`, and Adapt() takes its vector by
// value, so passing `input_versions` as an lvalue copies it again;
// std::move(input_versions) at the call would avoid the second copy.
190 template <
typename InputType,
typename OutputType>
193 const StringPiece servable_name(input.id().name);
194 std::vector<ServableData<InputType>> input_versions = {input};
195 std::vector<ServableData<OutputType>> output_versions =
196 Adapt(servable_name, input_versions);
197 DCHECK_EQ(1, output_versions.size());
// std::move is required here: the operand is a vector element, not a local.
198 return std::move(output_versions[0]);
201 template <
typename InputType,
typename OutputType>
// Adapts the input versions one at a time, appending results to
// adapted_versions in iteration order, and returns that vector:
//   - input status OK and Convert() returns OK -> the converted data;
//   - input status OK and Convert() fails      -> Convert()'s Status;
//   - otherwise (presumably: input status not OK) -> the input's own Status.
// Every emitted item carries the id of the input version it came from.
// NOTE(review): the lines holding the else-branches and closing braces are
// missing from this extract; the branch structure above is inferred from the
// three emplace_back calls -- confirm against the full file.
// `servable_name` is not used by the visible lines.
204 template <
typename InputType,
typename OutputType>
205 std::vector<ServableData<OutputType>>
206 UnarySourceAdapter<InputType, OutputType>::Adapt(
207 const StringPiece servable_name,
208 std::vector<ServableData<InputType>> versions) {
209 std::vector<ServableData<OutputType>> adapted_versions;
210 for (
const ServableData<InputType>& version : versions) {
211 if (version.status().ok()) {
// Convert() fills adapted_data; on success it is moved into the result.
212 OutputType adapted_data;
213 Status adapt_status = Convert(version.DataOrDie(), &adapted_data);
214 if (adapt_status.ok()) {
215 adapted_versions.emplace_back(
216 ServableData<OutputType>{version.id(), std::move(adapted_data)});
218 adapted_versions.emplace_back(
219 ServableData<OutputType>{version.id(), adapt_status});
222 adapted_versions.emplace_back(
223 ServableData<OutputType>{version.id(), version.status()});
226 return adapted_versions;
229 template <
typename InputType,
typename OutputType>
230 ErrorInjectingSourceAdapter<InputType, OutputType>::ErrorInjectingSourceAdapter(
// Destructor: calls TargetBase<InputType>::Detach().
// NOTE(review): presumably this disconnects the adapter from whatever feeds it
// before its members are destroyed -- confirm against TargetBase.
236 template <
typename InputType,
typename OutputType>
237 ErrorInjectingSourceAdapter<InputType,
238 OutputType>::~ErrorInjectingSourceAdapter() {
239 TargetBase<InputType>::Detach();
// Builds and returns adapted_versions, in which no item carries data; every
// item holds the id of the input version it came from plus a Status:
//   - input status OK -> logs an INFO message starting
//     "Injecting error for servable " and emits the stored `error_`;
//   - otherwise (presumably: input status not OK) -> the input's own Status.
// NOTE(review): the tail of the log statement, the else-branch and the closing
// braces are missing from this extract; the branch structure above is inferred
// from the two emplace_back calls -- confirm against the full file.
// `servable_name` is not used by the visible lines.
242 template <
typename InputType,
typename OutputType>
243 std::vector<ServableData<OutputType>>
244 ErrorInjectingSourceAdapter<InputType, OutputType>::Adapt(
245 const StringPiece servable_name,
246 std::vector<ServableData<InputType>> versions) {
247 std::vector<ServableData<OutputType>> adapted_versions;
248 for (
const ServableData<InputType>& version : versions) {
249 if (version.status().ok()) {
250 LOG(INFO) <<
"Injecting error for servable " << version.id() <<
": "
252 adapted_versions.emplace_back(
253 ServableData<OutputType>{version.id(), error_});
255 adapted_versions.emplace_back(
256 ServableData<OutputType>{version.id(), version.status()});
259 return adapted_versions;
virtual std::vector< ServableData< OutputType > > Adapt(const StringPiece servable_name, std::vector< ServableData< InputType >> versions)=0
void SetAspiredVersions(const StringPiece servable_name, std::vector< ServableData< InputType >> versions) final
ServableData< OutputType > AdaptOneVersion(ServableData< InputType > input)
Adapts a single servable data item. (Implemented on top of Adapt().)