TensorFlow Serving C++ API Documentation
target.h
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_CORE_TARGET_H_
17 #define TENSORFLOW_SERVING_CORE_TARGET_H_
18 
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22 
23 #include "tensorflow/core/lib/core/notification.h"
24 #include "tensorflow/core/lib/core/stringpiece.h"
25 #include "tensorflow/core/lib/io/path.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow_serving/core/servable_data.h"
28 #include "tensorflow_serving/core/source.h"
29 #include "tensorflow_serving/util/observer.h"
30 
31 namespace tensorflow {
32 namespace serving {
33 
34 // An abstraction for a module that receives instructions on servables to load,
35 // from a Source. See source.h for documentation.
36 template <typename T>
37 class Target {
38  public:
39  virtual ~Target() = default;
40 
41  // Supplies a callback for a Source to use to supply aspired versions. See
42  // Source<T>::AspiredVersionsCallback for the semantics of aspired versions.
43  //
44  // The returned function satisfies these properties:
45  // - It is thread-safe.
46  // - It is valid forever, even after this Target object has been destroyed.
47  // After this Target is gone, the function becomes a no-op.
48  // - It blocks until the target has been fully set up and is able to handle
49  // the incoming request.
50  virtual typename Source<T>::AspiredVersionsCallback
51  GetAspiredVersionsCallback() = 0;
52 };
53 
54 // A base class for Target implementations. Takes care of ensuring that the
55 // emitted aspired-versions callbacks outlive the Target object. Target
56 // implementations should extend TargetBase.
57 //
58 // IMPORTANT: Every leaf derived class must call Detach() at the top of its
59 // destructor. (See documentation on Detach() below.)
60 template <typename T>
61 class TargetBase : public Target<T> {
62  public:
63  ~TargetBase() override;
64 
65  typename Source<T>::AspiredVersionsCallback GetAspiredVersionsCallback()
66  final;
67 
68  protected:
69  // This is an abstract class.
70  TargetBase();
71 
72  // A method supplied by the implementing subclass to handle incoming aspired-
73  // versions requests from sources.
74  //
75  // IMPORTANT: The SetAspiredVersions() implementation must be thread-safe, to
76  // handle the case of multiple sources (or one multi-threaded source).
77  //
78  // May block until the target has been fully set up and is able to handle the
79  // incoming request.
80  virtual void SetAspiredVersions(const StringPiece servable_name,
81  std::vector<ServableData<T>> versions) = 0;
82 
83  // Stops receiving SetAspiredVersions() calls. Every leaf derived class (i.e.
84  // sub-sub-...-class with no children) must call Detach() at the top of its
85  // destructor to avoid races with state destruction. After Detach() returns,
86  // it is guaranteed that no SetAspiredVersions() calls are running (in any
87  // thread) and no new ones can run. Detach() must be called exactly once.
88  void Detach();
89 
90  private:
91  // Used to synchronize all class state. The shared pointer permits use in an
92  // observer lambda while being impervious to this class's destruction.
93  mutable std::shared_ptr<mutex> mu_;
94 
95  // Notified when Detach() has been called. The shared pointer permits use in
96  // an observer lambda while being impervious to this class's destruction.
97  std::shared_ptr<Notification> detached_;
98 
99  // An observer object that forwards to SetAspiredVersions(), if not detached.
100  std::unique_ptr<Observer<const StringPiece, std::vector<ServableData<T>>>>
101  observer_;
102 };
103 
104 // Connects a source to a target, s.t. the target will receive the source's
105 // aspired-versions requests.
106 template <typename T>
107 void ConnectSourceToTarget(Source<T>* source, Target<T>* target);
108 
110 // Implementation details follow. API users need not read.
111 
112 template <typename T>
113 TargetBase<T>::TargetBase() : mu_(new mutex), detached_(new Notification) {
114  std::shared_ptr<mutex> mu = mu_;
115  std::shared_ptr<Notification> detached = detached_;
116  observer_.reset(new Observer<const StringPiece, std::vector<ServableData<T>>>(
117  [mu, detached, this](const StringPiece servable_name,
118  std::vector<ServableData<T>> versions) {
119  mutex_lock l(*mu);
120  if (detached->HasBeenNotified()) {
121  // We're detached. Perform a no-op.
122  return;
123  }
124  this->SetAspiredVersions(servable_name, std::move(versions));
125  }));
126 }
127 
128 template <typename T>
129 TargetBase<T>::~TargetBase() {
130  DCHECK(detached_->HasBeenNotified()) << "Detach() must be called exactly "
131  "once, at the top of the leaf "
132  "derived class's destructor";
133 }
134 
135 template <typename T>
137 TargetBase<T>::GetAspiredVersionsCallback() {
138  mutex_lock l(*mu_);
139  if (detached_->HasBeenNotified()) {
140  // We're detached. Return a no-op callback.
141  return [](const StringPiece, std::vector<ServableData<T>>) {};
142  }
143  return observer_->Notifier();
144 }
145 
146 template <typename T>
147 void TargetBase<T>::Detach() {
148  DCHECK(!detached_->HasBeenNotified()) << "Detach() must be called exactly "
149  "once, at the top of the leaf "
150  "derived class's destructor";
151 
152  // We defer deleting the observer until after we've released the lock, to
153  // avoid a deadlock with the observer's internal lock when it calls our
154  // lambda.
155  std::unique_ptr<Observer<const StringPiece, std::vector<ServableData<T>>>>
156  detached_observer;
157  {
158  mutex_lock l(*mu_);
159  detached_observer = std::move(observer_);
160  if (!detached_->HasBeenNotified()) {
161  detached_->Notify();
162  }
163  }
164 }
165 
166 template <typename T>
167 void ConnectSourceToTarget(Source<T>* source, Target<T>* target) {
168  source->SetAspiredVersionsCallback(target->GetAspiredVersionsCallback());
169 }
170 
171 } // namespace serving
172 } // namespace tensorflow
173 
174 #endif // TENSORFLOW_SERVING_CORE_TARGET_H_
// Referenced definition (source.h:88): AspiredVersionsCallback is
// std::function<void(const StringPiece servable_name,
//                    std::vector<ServableData<T>> versions)>.