// TensorFlow Serving C++ API Documentation
// source_router.h
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_CORE_SOURCE_ROUTER_H_
17 #define TENSORFLOW_SERVING_CORE_SOURCE_ROUTER_H_
18 
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22 
23 #include "tensorflow/core/lib/core/notification.h"
24 #include "tensorflow_serving/core/source.h"
25 #include "tensorflow_serving/core/source_adapter.h"
26 #include "tensorflow_serving/core/target.h"
27 
28 namespace tensorflow {
29 namespace serving {
30 
31 // A module that splits aspired-version calls from one input to multiple outputs
32 // based on some criterion, e.g. the servable name. There is a single input
33 // "port" represented by Target<T>::SetAspiredVersions(), and N output "ports",
34 // each of type Source<T>, numbered 0, 1, 2, ...
35 //
36 // For a typical use-case, consider a server hosting multiple kinds of servables
37 // (say, apple servables and orange servables). Perhaps both kinds of servables
38 // arrive via file-system paths, a la:
39 // /path/to/some/apple/servable
40 // /path/to/some/orange/servable
41 // where the servable kinds are distinguished based on the presence of "apple"
42 // or "orange" in the path. A SourceRouter can be interposed between a file-
43 // system monitoring Source<StoragePath>, and a pair of SourceAdapters (one that
44 // emits loaders of apple servables, and one that emits loaders of orange
45 // servables), to route each path to the appropriate SourceAdapter.
46 //
47 // IMPORTANT: Every leaf derived class must call Detach() at the top of its
48 // destructor. (See documentation on TargetBase::Detach() in target.h.) Doing so
49 // ensures that no virtual method calls are in flight during destruction of
50 // member variables.
51 template <typename T>
52 class SourceRouter : public TargetBase<T> {
53  public:
54  static constexpr int kNoRoute = -1;
55 
56  ~SourceRouter() override = 0;
57 
58  // Returns a vector of N source pointers, corresponding to the N output ports
59  // of the router. The caller must invoke ConnectSourceToTarget() (or directly
60  // call SetAspiredVersionsCallback()) on each of them to arrange to route
61  // items to various upstream targets. That must be done exactly once, and
62  // before calling SetAspiredVersions() on the router.
63  std::vector<Source<T>*> GetOutputPorts();
64 
65  // Implemented in terms of Route(), defined below and written by the subclass.
66  void SetAspiredVersions(const StringPiece servable_name,
67  std::vector<ServableData<T>> versions) final;
68 
69  protected:
70  // This is an abstract class.
71  SourceRouter() = default;
72 
73  // Returns the number of output ports. Must be > 0 and fixed for the lifetime
74  // of the router. To be written by the implementing subclass.
75  virtual int num_output_ports() const = 0;
76 
77  // Returns `kNoRoute` or a valid output port # in [0, num_output_ports() - 1].
78  // Aspired-versions requests will be routed to the output port corresponding
79  // to the returned port number. If `kNoRoute` is returned, the aspired-version
80  // request will be discarded silently. To be written by the implementing
81  // subclass.
82  virtual int Route(const StringPiece servable_name,
83  const std::vector<ServableData<T>>& versions) = 0;
84 
85  private:
86  // The num_output_ports() output ports. Each one is an IdentitySourceAdapter.
87  // Populated in GetOutputPorts().
88  std::vector<std::unique_ptr<SourceAdapter<T, T>>> output_ports_;
89 
90  // Has 'output_ports_' been populated yet, so that the SourceAdapter is ready
91  // to propagate aspired versions?
92  Notification output_ports_created_;
93 };
94 
// Implementation details follow. API users need not read.
97 
98 namespace internal {
99 
100 // A SourceAdapter that passes through data unchanged. Used to implement the
101 // output ports.
102 template <typename T>
103 class IdentitySourceAdapter final : public SourceAdapter<T, T> {
104  public:
105  IdentitySourceAdapter() = default;
107 
108  private:
109  std::vector<ServableData<T>> Adapt(
110  const StringPiece servable_name,
111  std::vector<ServableData<T>> versions) final;
112 
113  TF_DISALLOW_COPY_AND_ASSIGN(IdentitySourceAdapter);
114 };
115 
116 template <typename T>
117 std::vector<ServableData<T>> IdentitySourceAdapter<T>::Adapt(
118  const StringPiece servable_name, std::vector<ServableData<T>> versions) {
119  return versions;
120 }
121 
122 } // namespace internal
123 
124 template <typename T>
125 SourceRouter<T>::~SourceRouter() {}
126 
127 template <typename T>
128 std::vector<Source<T>*> SourceRouter<T>::GetOutputPorts() {
129  if (!output_ports_created_.HasBeenNotified()) {
130  int num_ports = num_output_ports();
131  if (num_ports < 1) {
132  LOG(ERROR) << "SourceRouter abstraction used improperly; "
133  "num_output_ports() must return a number greater than 0";
134  DCHECK(false);
135  num_ports = 1;
136  }
137  for (int i = 0; i < num_ports; ++i) {
138  output_ports_.emplace_back(new internal::IdentitySourceAdapter<T>);
139  }
140  output_ports_created_.Notify();
141  }
142 
143  std::vector<Source<T>*> result;
144  for (auto& output_port : output_ports_) {
145  result.push_back(output_port.get());
146  }
147  return result;
148 }
149 
150 template <typename T>
151 void SourceRouter<T>::SetAspiredVersions(
152  const StringPiece servable_name, std::vector<ServableData<T>> versions) {
153  output_ports_created_.WaitForNotification();
154  int output_port = Route(servable_name, versions);
155  if (output_port == kNoRoute) {
156  return;
157  }
158  if (output_port < 0 || output_port > output_ports_.size() - 1) {
159  LOG(ERROR)
160  << "SourceRouter abstraction used improperly; Route() must return "
161  "kNoRoute or a value in [0, num_output_ports()-1]; suppressing the "
162  "aspired-versions request";
163  DCHECK(false);
164  return;
165  }
166  output_ports_[output_port]->SetAspiredVersions(servable_name,
167  std::move(versions));
168 }
169 
170 } // namespace serving
171 } // namespace tensorflow
172 
173 #endif // TENSORFLOW_SERVING_CORE_SOURCE_ROUTER_H_