16 #include "tensorflow_serving/core/servable_state_monitor.h"
22 #include "absl/time/time.h"
23 #include "tensorflow/core/lib/core/notification.h"
24 #include "tensorflow/core/lib/gtl/cleanup.h"
25 #include "tensorflow_serving/core/servable_state.h"
27 namespace tensorflow {
31 void EraseLiveStatesEntry(
32 const ServableStateMonitor::ServableStateAndTime& state_and_time,
33 ServableStateMonitor::ServableMap*
const live_states) {
34 const string& servable_name = state_and_time.state.id.name;
35 const int64_t version = state_and_time.state.id.version;
36 auto servable_map_it = live_states->find(servable_name);
37 if (servable_map_it == live_states->end()) {
40 auto& version_map = servable_map_it->second;
41 auto version_map_it = version_map.find(version);
42 if (version_map_it == version_map.end()) {
46 version_map.erase(version_map_it);
47 if (version_map.empty()) {
48 live_states->erase(servable_map_it);
52 void UpdateLiveStates(
53 const ServableStateMonitor::ServableStateAndTime& state_and_time,
54 ServableStateMonitor::ServableMap*
const live_states) {
55 const string& servable_name = state_and_time.state.id.name;
56 const int64_t version = state_and_time.state.id.version;
57 if (state_and_time.state.manager_state != ServableState::ManagerState::kEnd) {
58 (*live_states)[servable_name][version] = state_and_time;
60 EraseLiveStatesEntry(state_and_time, live_states);
66 absl::optional<ServableState::ManagerState> HasSpecificServableReachedState(
67 const ServableId& servable_id,
const ServableState::ManagerState goal_state,
68 const absl::optional<ServableStateMonitor::ServableStateAndTime>
69 opt_servable_state_time) {
70 if (!opt_servable_state_time) {
73 const ServableState::ManagerState state =
74 opt_servable_state_time->state.manager_state;
75 if (state != goal_state && state != ServableState::ManagerState::kEnd) {
83 absl::optional<ServableId> HasAnyServableInStreamReachedState(
84 const string& stream_name,
const ServableState::ManagerState goal_state,
85 const ServableStateMonitor::ServableMap& states) {
86 absl::optional<ServableId> opt_servable_id;
87 const auto found_it = states.find(stream_name);
88 if (found_it == states.end()) {
91 const ServableStateMonitor::VersionMap& version_map = found_it->second;
92 for (
const auto& version_and_state_time : version_map) {
93 const ServableStateMonitor::ServableStateAndTime& state_and_time =
94 version_and_state_time.second;
95 if (state_and_time.state.manager_state == goal_state ||
96 state_and_time.state.manager_state ==
97 ServableState::ManagerState::kEnd) {
98 return {version_and_state_time.second.state.id};
107 return strings::StrCat(
"state: {",
state.DebugString(),
113 : options_(options) {
119 this->HandleEvent(state_and_time);
123 ServableStateMonitor::~ServableStateMonitor() {
126 bus_subscription_ =
nullptr;
129 absl::optional<ServableStateMonitor::ServableStateAndTime>
130 ServableStateMonitor::GetStateAndTimeInternal(
131 const ServableId& servable_id)
const {
132 auto it = states_.find(servable_id.name);
133 if (it == states_.end()) {
134 return absl::nullopt;
136 const VersionMap& versions = it->second;
137 auto it2 = versions.find(servable_id.version);
138 if (it2 == versions.end()) {
139 return absl::nullopt;
144 absl::optional<ServableStateMonitor::ServableStateAndTime>
147 return GetStateAndTimeInternal(servable_id);
152 const absl::optional<ServableStateAndTime>& state_and_time =
154 if (!state_and_time) {
155 return absl::nullopt;
157 return state_and_time->state;
161 const string& servable_name)
const {
163 auto it = states_.find(servable_name);
164 if (it == states_.end()) {
185 for (
auto& state : states_) {
186 std::vector<Version> versions_to_forget;
187 auto& version_map = state.second;
188 for (
const auto& version : version_map) {
189 if (version.second.state.manager_state ==
190 ServableState::ManagerState::kEnd) {
191 versions_to_forget.emplace_back(version.first);
194 for (
const auto& version : versions_to_forget) {
195 version_map.erase(version);
200 ServableStateMonitor::ServableSet
201 ServableStateMonitor::GetAvailableServableStates()
const {
202 ServableSet available_servable_set;
204 for (
const auto& live_state : live_states_) {
205 const string& servable_name = live_state.first;
206 const auto& version_map = live_state.second;
207 for (
const auto& version : version_map) {
208 const ServableStateAndTime state_and_time = version.second;
209 if (state_and_time.state.manager_state ==
210 ServableState::ManagerState::kAvailable) {
211 available_servable_set.insert(servable_name);
215 return available_servable_set;
223 void ServableStateMonitor::NotifyWhenServablesReachState(
224 const std::vector<ServableRequest>& servables,
225 const ServableState::ManagerState goal_state,
228 servable_state_notification_requests_.push_back(
229 {servables, goal_state, notifier_fn});
230 MaybeSendStateReachedNotifications();
233 void ServableStateMonitor::Notify(
const NotifyFn& notify_fn) {
234 mutex_lock l(notify_mu_);
235 notify_fns_.push_back(notify_fn);
239 const std::vector<ServableRequest>& servables,
240 const ServableState::ManagerState goal_state, absl::Duration timeout,
241 std::map<ServableId, ServableState::ManagerState>*
const states_reached) {
242 bool reached_goal_state =
false;
243 Notification notified;
244 NotifyWhenServablesReachState(
245 servables, goal_state,
246 [&](
const bool incoming_reached_goal_state,
247 const std::map<ServableId, ServableState::ManagerState>&
248 incoming_states_reached) {
249 if (states_reached !=
nullptr) {
250 *states_reached = incoming_states_reached;
252 reached_goal_state = incoming_reached_goal_state;
255 notified.WaitForNotificationWithTimeout(timeout);
256 return reached_goal_state;
259 bool ServableStateMonitor::WaitUntilServablesReachState(
260 const std::vector<ServableRequest>& servables,
261 const ServableState::ManagerState goal_state,
262 std::map<ServableId, ServableState::ManagerState>*
const states_reached) {
264 servables, goal_state,
265 absl::InfiniteDuration(), states_reached);
268 void ServableStateMonitor::PreHandleEvent(
269 const EventBus<ServableState>::EventAndTime& state_and_time) {}
271 void ServableStateMonitor::HandleEvent(
272 const EventBus<ServableState>::EventAndTime& event_and_time) {
273 PreHandleEvent(event_and_time);
276 gtl::MakeCleanup([&]() { SendNotifications(event_and_time.event); });
279 const ServableStateAndTime state_and_time = {
280 event_and_time.event, event_and_time.event_time_micros};
281 states_[state_and_time.state.id.name][state_and_time.state.id.version] =
283 UpdateLiveStates(state_and_time, &live_states_);
284 MaybeSendStateReachedNotifications();
292 log_.emplace_back(state_and_time.state, state_and_time.event_time_micros);
296 std::pair<bool, std::map<ServableId, ServableState::ManagerState>>>
297 ServableStateMonitor::ShouldSendStateReachedNotification(
298 const ServableStateNotificationRequest& notification_request) {
299 bool reached_goal_state =
true;
300 std::map<ServableId, ServableState::ManagerState> states_reached;
301 for (
const auto& servable_request : notification_request.servables) {
302 if (servable_request.version) {
303 const ServableId servable_id = {servable_request.name,
304 *servable_request.version};
305 const absl::optional<ServableState::ManagerState> opt_state =
306 HasSpecificServableReachedState(servable_id,
307 notification_request.goal_state,
308 GetStateAndTimeInternal(servable_id));
314 reached_goal_state && *opt_state == notification_request.goal_state;
315 states_reached[servable_id] = *opt_state;
317 const absl::optional<ServableId> opt_servable_id =
318 HasAnyServableInStreamReachedState(
319 servable_request.name, notification_request.goal_state, states_);
320 if (!opt_servable_id) {
323 const ServableState::ManagerState reached_state =
324 GetStateAndTimeInternal(*opt_servable_id)->state.manager_state;
326 reached_goal_state = reached_goal_state &&
327 reached_state == notification_request.goal_state;
328 states_reached[*opt_servable_id] = reached_state;
331 return {{reached_goal_state, states_reached}};
334 void ServableStateMonitor::MaybeSendStateReachedNotifications() {
335 for (
auto iter = servable_state_notification_requests_.begin();
336 iter != servable_state_notification_requests_.end();) {
337 const ServableStateNotificationRequest& notification_request = *iter;
338 const absl::optional<
339 std::pair<bool, std::map<ServableId, ServableState::ManagerState>>>
340 opt_state_and_states_reached =
341 ShouldSendStateReachedNotification(notification_request);
342 if (opt_state_and_states_reached) {
343 notification_request.notifier_fn(opt_state_and_states_reached->first,
344 opt_state_and_states_reached->second);
345 iter = servable_state_notification_requests_.erase(iter);
352 void ServableStateMonitor::SendNotifications(
353 const ServableState& servable_state) {
354 mutex_lock l(notify_mu_);
355 for (
const auto& notify_fn : notify_fns_) {
356 notify_fn(servable_state);
std::unique_ptr< Subscription > Subscribe(const Callback &callback) TF_LOCKS_EXCLUDED(mutex_) TF_MUST_USE_RESULT
BoundedLog GetBoundedLog() const TF_LOCKS_EXCLUDED(mu_)
Returns the current bounded log of handled servable state events.
VersionMap GetVersionStates(const string &servable_name) const TF_LOCKS_EXCLUDED(mu_)
ServableMap GetLiveServableStates() const TF_LOCKS_EXCLUDED(mu_)
std::function< void(bool reached_goal_state, const std::map< ServableId, ServableState::ManagerState > &states_reached)> ServableStateNotifierFn
absl::optional< ServableState > GetState(const ServableId &servable_id) const TF_LOCKS_EXCLUDED(mu_)
absl::optional< ServableStateAndTime > GetStateAndTime(const ServableId &servable_id) const TF_LOCKS_EXCLUDED(mu_)
bool WaitUntilServablesReachStateWithTimeout(const std::vector< ServableRequest > &servables, ServableState::ManagerState goal_state, absl::Duration timeout, std::map< ServableId, ServableState::ManagerState > *states_reached=nullptr) TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT
ServableMap GetAllServableStates() const TF_LOCKS_EXCLUDED(mu_)
Returns the current states of all tracked versions of all servables.
void ForgetUnloadedServableStates() TF_LOCKS_EXCLUDED(mu_)
Event and the publish time associated with it.
uint64_t max_count_log_events
string DebugString() const
ServableState state
State of the servable.
uint64_t event_time_micros
Time at which servable state event was published.