TensorFlow Serving C++ API Documentation
servable_state_monitor.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/servable_state_monitor.h"
17 
#include <map>
#include <memory>
#include <utility>
#include <vector>
21 
22 #include "absl/time/time.h"
23 #include "tensorflow/core/lib/core/notification.h"
24 #include "tensorflow/core/lib/gtl/cleanup.h"
25 #include "tensorflow_serving/core/servable_state.h"
26 
27 namespace tensorflow {
28 namespace serving {
29 namespace {
30 
31 void EraseLiveStatesEntry(
32  const ServableStateMonitor::ServableStateAndTime& state_and_time,
33  ServableStateMonitor::ServableMap* const live_states) {
34  const string& servable_name = state_and_time.state.id.name;
35  const int64_t version = state_and_time.state.id.version;
36  auto servable_map_it = live_states->find(servable_name);
37  if (servable_map_it == live_states->end()) {
38  return;
39  }
40  auto& version_map = servable_map_it->second;
41  auto version_map_it = version_map.find(version);
42  if (version_map_it == version_map.end()) {
43  return;
44  }
45 
46  version_map.erase(version_map_it);
47  if (version_map.empty()) {
48  live_states->erase(servable_map_it);
49  }
50 }
51 
52 void UpdateLiveStates(
53  const ServableStateMonitor::ServableStateAndTime& state_and_time,
54  ServableStateMonitor::ServableMap* const live_states) {
55  const string& servable_name = state_and_time.state.id.name;
56  const int64_t version = state_and_time.state.id.version;
57  if (state_and_time.state.manager_state != ServableState::ManagerState::kEnd) {
58  (*live_states)[servable_name][version] = state_and_time;
59  } else {
60  EraseLiveStatesEntry(state_and_time, live_states);
61  }
62 }
63 
64 // Returns the state reached iff the servable has reached 'goal_state' or kEnd,
65 // otherwise nullopt.
66 absl::optional<ServableState::ManagerState> HasSpecificServableReachedState(
67  const ServableId& servable_id, const ServableState::ManagerState goal_state,
68  const absl::optional<ServableStateMonitor::ServableStateAndTime>
69  opt_servable_state_time) {
70  if (!opt_servable_state_time) {
71  return {};
72  }
73  const ServableState::ManagerState state =
74  opt_servable_state_time->state.manager_state;
75  if (state != goal_state && state != ServableState::ManagerState::kEnd) {
76  return {};
77  }
78  return {state};
79 }
80 
81 // Returns the id of the servable in the stream which has reached 'goal_state'
82 // or kEnd. If no servable has done so, returns nullopt.
83 absl::optional<ServableId> HasAnyServableInStreamReachedState(
84  const string& stream_name, const ServableState::ManagerState goal_state,
85  const ServableStateMonitor::ServableMap& states) {
86  absl::optional<ServableId> opt_servable_id;
87  const auto found_it = states.find(stream_name);
88  if (found_it == states.end()) {
89  return {};
90  }
91  const ServableStateMonitor::VersionMap& version_map = found_it->second;
92  for (const auto& version_and_state_time : version_map) {
93  const ServableStateMonitor::ServableStateAndTime& state_and_time =
94  version_and_state_time.second;
95  if (state_and_time.state.manager_state == goal_state ||
96  state_and_time.state.manager_state ==
97  ServableState::ManagerState::kEnd) {
98  return {version_and_state_time.second.state.id};
99  }
100  }
101  return {};
102 }
103 
104 } // namespace
105 
107  return strings::StrCat("state: {", state.DebugString(),
108  "}, event_time_micros: ", event_time_micros);
109 }
110 
111 ServableStateMonitor::ServableStateMonitor(EventBus<ServableState>* bus,
112  const Options& options)
113  : options_(options) {
114  // Important: We must allow the state members ('states_', 'live_states_' and
115  // so on) to be initialized *before* we start the bus subscription, in case an
116  // event comes in while we are initializing.
117  bus_subscription_ = bus->Subscribe(
118  [this](const EventBus<ServableState>::EventAndTime& state_and_time) {
119  this->HandleEvent(state_and_time);
120  });
121 }
122 
123 ServableStateMonitor::~ServableStateMonitor() {
124  // Halt event handling first, before tearing down state that event handling
125  // may access such as 'servable_state_notification_requests_'.
126  bus_subscription_ = nullptr;
127 }
128 
129 absl::optional<ServableStateMonitor::ServableStateAndTime>
130 ServableStateMonitor::GetStateAndTimeInternal(
131  const ServableId& servable_id) const {
132  auto it = states_.find(servable_id.name);
133  if (it == states_.end()) {
134  return absl::nullopt;
135  }
136  const VersionMap& versions = it->second;
137  auto it2 = versions.find(servable_id.version);
138  if (it2 == versions.end()) {
139  return absl::nullopt;
140  }
141  return it2->second;
142 }
143 
144 absl::optional<ServableStateMonitor::ServableStateAndTime>
146  mutex_lock l(mu_);
147  return GetStateAndTimeInternal(servable_id);
148 }
149 
150 absl::optional<ServableState> ServableStateMonitor::GetState(
151  const ServableId& servable_id) const {
152  const absl::optional<ServableStateAndTime>& state_and_time =
153  GetStateAndTime(servable_id);
154  if (!state_and_time) {
155  return absl::nullopt;
156  }
157  return state_and_time->state;
158 }
159 
160 ServableStateMonitor::VersionMap ServableStateMonitor::GetVersionStates(
161  const string& servable_name) const {
162  mutex_lock l(mu_);
163  auto it = states_.find(servable_name);
164  if (it == states_.end()) {
165  return {};
166  }
167  return it->second;
168 }
169 
170 ServableStateMonitor::ServableMap ServableStateMonitor::GetAllServableStates()
171  const {
172  mutex_lock l(mu_);
173  return states_;
174 }
175 
176 ServableStateMonitor::ServableMap ServableStateMonitor::GetLiveServableStates()
177  const {
178  mutex_lock l(mu_);
179  return live_states_;
180 }
181 
183  mutex_lock l(mu_);
184 
185  for (auto& state : states_) {
186  std::vector<Version> versions_to_forget;
187  auto& version_map = state.second;
188  for (const auto& version : version_map) {
189  if (version.second.state.manager_state ==
190  ServableState::ManagerState::kEnd) {
191  versions_to_forget.emplace_back(version.first);
192  }
193  }
194  for (const auto& version : versions_to_forget) {
195  version_map.erase(version);
196  }
197  }
198 }
199 
200 ServableStateMonitor::ServableSet
201 ServableStateMonitor::GetAvailableServableStates() const {
202  ServableSet available_servable_set;
203  mutex_lock l(mu_);
204  for (const auto& live_state : live_states_) {
205  const string& servable_name = live_state.first;
206  const auto& version_map = live_state.second;
207  for (const auto& version : version_map) {
208  const ServableStateAndTime state_and_time = version.second;
209  if (state_and_time.state.manager_state ==
210  ServableState::ManagerState::kAvailable) {
211  available_servable_set.insert(servable_name);
212  }
213  }
214  }
215  return available_servable_set;
216 }
217 
218 ServableStateMonitor::BoundedLog ServableStateMonitor::GetBoundedLog() const {
219  mutex_lock l(mu_);
220  return log_;
221 }
222 
223 void ServableStateMonitor::NotifyWhenServablesReachState(
224  const std::vector<ServableRequest>& servables,
225  const ServableState::ManagerState goal_state,
226  const ServableStateNotifierFn& notifier_fn) {
227  mutex_lock l(mu_);
228  servable_state_notification_requests_.push_back(
229  {servables, goal_state, notifier_fn});
230  MaybeSendStateReachedNotifications();
231 }
232 
233 void ServableStateMonitor::Notify(const NotifyFn& notify_fn) {
234  mutex_lock l(notify_mu_);
235  notify_fns_.push_back(notify_fn);
236 }
237 
239  const std::vector<ServableRequest>& servables,
240  const ServableState::ManagerState goal_state, absl::Duration timeout,
241  std::map<ServableId, ServableState::ManagerState>* const states_reached) {
242  bool reached_goal_state = false;
243  Notification notified;
244  NotifyWhenServablesReachState(
245  servables, goal_state,
246  [&](const bool incoming_reached_goal_state,
247  const std::map<ServableId, ServableState::ManagerState>&
248  incoming_states_reached) {
249  if (states_reached != nullptr) {
250  *states_reached = incoming_states_reached;
251  }
252  reached_goal_state = incoming_reached_goal_state;
253  notified.Notify();
254  });
255  notified.WaitForNotificationWithTimeout(timeout);
256  return reached_goal_state;
257 }
258 
259 bool ServableStateMonitor::WaitUntilServablesReachState(
260  const std::vector<ServableRequest>& servables,
261  const ServableState::ManagerState goal_state,
262  std::map<ServableId, ServableState::ManagerState>* const states_reached) {
264  servables, goal_state,
265  /*timeout=*/absl::InfiniteDuration(), states_reached);
266 }
267 
268 void ServableStateMonitor::PreHandleEvent(
269  const EventBus<ServableState>::EventAndTime& state_and_time) {}
270 
271 void ServableStateMonitor::HandleEvent(
272  const EventBus<ServableState>::EventAndTime& event_and_time) {
273  PreHandleEvent(event_and_time);
274 
275  auto cleanup =
276  gtl::MakeCleanup([&]() { SendNotifications(event_and_time.event); });
277 
278  mutex_lock l(mu_);
279  const ServableStateAndTime state_and_time = {
280  event_and_time.event, event_and_time.event_time_micros};
281  states_[state_and_time.state.id.name][state_and_time.state.id.version] =
282  state_and_time;
283  UpdateLiveStates(state_and_time, &live_states_);
284  MaybeSendStateReachedNotifications();
285 
286  if (options_.max_count_log_events == 0) {
287  return;
288  }
289  while (log_.size() >= options_.max_count_log_events) {
290  log_.pop_front();
291  }
292  log_.emplace_back(state_and_time.state, state_and_time.event_time_micros);
293 }
294 
295 absl::optional<
296  std::pair<bool, std::map<ServableId, ServableState::ManagerState>>>
297 ServableStateMonitor::ShouldSendStateReachedNotification(
298  const ServableStateNotificationRequest& notification_request) {
299  bool reached_goal_state = true;
300  std::map<ServableId, ServableState::ManagerState> states_reached;
301  for (const auto& servable_request : notification_request.servables) {
302  if (servable_request.version) {
303  const ServableId servable_id = {servable_request.name,
304  *servable_request.version};
305  const absl::optional<ServableState::ManagerState> opt_state =
306  HasSpecificServableReachedState(servable_id,
307  notification_request.goal_state,
308  GetStateAndTimeInternal(servable_id));
309  if (!opt_state) {
310  return {};
311  }
312  // Remains false once false.
313  reached_goal_state =
314  reached_goal_state && *opt_state == notification_request.goal_state;
315  states_reached[servable_id] = *opt_state;
316  } else {
317  const absl::optional<ServableId> opt_servable_id =
318  HasAnyServableInStreamReachedState(
319  servable_request.name, notification_request.goal_state, states_);
320  if (!opt_servable_id) {
321  return {};
322  }
323  const ServableState::ManagerState reached_state =
324  GetStateAndTimeInternal(*opt_servable_id)->state.manager_state;
325  // Remains false once false.
326  reached_goal_state = reached_goal_state &&
327  reached_state == notification_request.goal_state;
328  states_reached[*opt_servable_id] = reached_state;
329  }
330  }
331  return {{reached_goal_state, states_reached}};
332 }
333 
334 void ServableStateMonitor::MaybeSendStateReachedNotifications() {
335  for (auto iter = servable_state_notification_requests_.begin();
336  iter != servable_state_notification_requests_.end();) {
337  const ServableStateNotificationRequest& notification_request = *iter;
338  const absl::optional<
339  std::pair<bool, std::map<ServableId, ServableState::ManagerState>>>
340  opt_state_and_states_reached =
341  ShouldSendStateReachedNotification(notification_request);
342  if (opt_state_and_states_reached) {
343  notification_request.notifier_fn(opt_state_and_states_reached->first,
344  opt_state_and_states_reached->second);
345  iter = servable_state_notification_requests_.erase(iter);
346  } else {
347  ++iter;
348  }
349  }
350 }
351 
352 void ServableStateMonitor::SendNotifications(
353  const ServableState& servable_state) {
354  mutex_lock l(notify_mu_);
355  for (const auto& notify_fn : notify_fns_) {
356  notify_fn(servable_state);
357  }
358 }
359 
360 } // namespace serving
361 } // namespace tensorflow
std::unique_ptr< Subscription > Subscribe(const Callback &callback) TF_LOCKS_EXCLUDED(mutex_) TF_MUST_USE_RESULT
Definition: event_bus.h:178
BoundedLog GetBoundedLog() const TF_LOCKS_EXCLUDED(mu_)
Returns the current bounded log of handled servable state events.
VersionMap GetVersionStates(const string &servable_name) const TF_LOCKS_EXCLUDED(mu_)
ServableMap GetLiveServableStates() const TF_LOCKS_EXCLUDED(mu_)
std::function< void(bool reached_goal_state, const std::map< ServableId, ServableState::ManagerState > &states_reached)> ServableStateNotifierFn
absl::optional< ServableState > GetState(const ServableId &servable_id) const TF_LOCKS_EXCLUDED(mu_)
absl::optional< ServableStateAndTime > GetStateAndTime(const ServableId &servable_id) const TF_LOCKS_EXCLUDED(mu_)
bool WaitUntilServablesReachStateWithTimeout(const std::vector< ServableRequest > &servables, ServableState::ManagerState goal_state, absl::Duration timeout, std::map< ServableId, ServableState::ManagerState > *states_reached=nullptr) TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT
ServableMap GetAllServableStates() const TF_LOCKS_EXCLUDED(mu_)
Returns the current states of all tracked versions of all servables.
void ForgetUnloadedServableStates() TF_LOCKS_EXCLUDED(mu_)
Event and the publish time associated with it.
Definition: event_bus.h:103
uint64_t event_time_micros
Time at which servable state event was published.