TensorFlow Serving C++ API Documentation
servable_state_monitor_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/servable_state_monitor.h"
17 
18 #include <map>
19 #include <memory>
20 #include <vector>
21 
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/notification.h"
29 #include "tensorflow_serving/core/manager.h"
30 #include "tensorflow_serving/core/servable_id.h"
31 #include "tensorflow_serving/core/servable_state.h"
32 #include "tensorflow_serving/util/event_bus.h"
33 
34 namespace tensorflow {
35 namespace serving {
36 namespace {
37 
38 using ::testing::ElementsAre;
39 using ::testing::IsEmpty;
40 using ::testing::Pair;
41 using ::testing::UnorderedElementsAre;
42 using ServableStateAndTime = ServableStateMonitor::ServableStateAndTime;
43 
44 class ServableStateMonitorTest : public ::testing::Test {
45  protected:
46  ServableStateMonitorTest() {
47  env_ = std::make_unique<test_util::FakeClockEnv>(Env::Default());
48  EventBus<ServableState>::Options bus_options;
49  bus_options.env = env_.get();
50  bus_ = EventBus<ServableState>::CreateEventBus(bus_options);
51  }
52  void CreateMonitor(int max_count_log_events = 0) {
53  ServableStateMonitor::Options monitor_options;
54  monitor_options.max_count_log_events = max_count_log_events;
55  monitor_ =
56  std::make_unique<ServableStateMonitor>(bus_.get(), monitor_options);
57  }
58  std::unique_ptr<test_util::FakeClockEnv> env_;
59  std::shared_ptr<EventBus<ServableState>> bus_;
60  std::unique_ptr<ServableStateMonitor> monitor_;
61 };
62 
63 TEST_F(ServableStateMonitorTest, AddingStates) {
64  CreateMonitor(/*max_count_log_events=*/4);
65  ServableState notified_state;
66  monitor_->Notify([&](const ServableState& servable_state) {
67  notified_state = servable_state;
68  });
69  EXPECT_FALSE(monitor_->GetState(ServableId{"foo", 42}));
70  EXPECT_TRUE(monitor_->GetVersionStates("foo").empty());
71  EXPECT_TRUE(monitor_->GetAllServableStates().empty());
72  EXPECT_TRUE(monitor_->GetBoundedLog().empty());
73 
74  // Initial servable.
75  const ServableState state_0 = {
76  ServableId{"foo", 42}, ServableState::ManagerState::kStart, OkStatus()};
77  env_->AdvanceByMicroseconds(1);
78  const ServableStateAndTime state_0_and_time = {state_0, 1};
79  bus_->Publish(state_0);
80  ASSERT_TRUE(monitor_->GetState(ServableId{"foo", 42}));
81  EXPECT_EQ(state_0, *monitor_->GetState(ServableId{"foo", 42}));
82  EXPECT_EQ(state_0, notified_state);
83  EXPECT_FALSE(monitor_->GetState(ServableId{"foo", 99}));
84  EXPECT_FALSE(monitor_->GetState(ServableId{"bar", 42}));
85  EXPECT_THAT(monitor_->GetVersionStates("foo"),
86  ElementsAre(Pair(42, state_0_and_time)));
87  EXPECT_TRUE(monitor_->GetVersionStates("bar").empty());
88  EXPECT_THAT(monitor_->GetAllServableStates(),
89  UnorderedElementsAre(
90  Pair("foo", ElementsAre(Pair(42, state_0_and_time)))));
91  EXPECT_THAT(monitor_->GetBoundedLog(), ElementsAre(state_0_and_time));
92 
93  // New version of existing servable.
94  const ServableState state_1 = {ServableId{"foo", 43},
95  ServableState::ManagerState::kAvailable,
96  errors::Unknown("error")};
97  env_->AdvanceByMicroseconds(2);
98  const ServableStateAndTime state_1_and_time = {state_1, 3};
99  bus_->Publish(state_1);
100  ASSERT_TRUE(monitor_->GetState(ServableId{"foo", 42}));
101  EXPECT_EQ(state_0, *monitor_->GetState(ServableId{"foo", 42}));
102  ASSERT_TRUE(monitor_->GetState(ServableId{"foo", 43}));
103  EXPECT_EQ(state_1, *monitor_->GetState(ServableId{"foo", 43}));
104  EXPECT_EQ(state_1, notified_state);
105  EXPECT_FALSE(monitor_->GetState(ServableId{"foo", 99}));
106  EXPECT_FALSE(monitor_->GetState(ServableId{"bar", 42}));
107  EXPECT_THAT(
108  monitor_->GetVersionStates("foo"),
109  ElementsAre(Pair(43, state_1_and_time), Pair(42, state_0_and_time)));
110  EXPECT_TRUE(monitor_->GetVersionStates("bar").empty());
111  EXPECT_THAT(monitor_->GetAllServableStates(),
112  UnorderedElementsAre(
113  Pair("foo", ElementsAre(Pair(43, state_1_and_time),
114  Pair(42, state_0_and_time)))));
115  EXPECT_THAT(monitor_->GetBoundedLog(),
116  ElementsAre(state_0_and_time, state_1_and_time));
117 
118  // New servable name.
119  const ServableState state_2 = {ServableId{"bar", 7},
120  ServableState::ManagerState::kUnloading,
121  OkStatus()};
122  env_->AdvanceByMicroseconds(4);
123  const ServableStateAndTime state_2_and_time = {state_2, 7};
124  bus_->Publish(state_2);
125  ASSERT_TRUE(monitor_->GetState(ServableId{"foo", 42}));
126  EXPECT_EQ(state_0, *monitor_->GetState(ServableId{"foo", 42}));
127  ASSERT_TRUE(monitor_->GetState(ServableId{"foo", 43}));
128  EXPECT_EQ(state_1, *monitor_->GetState(ServableId{"foo", 43}));
129  ASSERT_TRUE(monitor_->GetState(ServableId{"bar", 7}));
130  EXPECT_EQ(state_2, *monitor_->GetState(ServableId{"bar", 7}));
131  EXPECT_EQ(state_2, notified_state);
132  EXPECT_FALSE(monitor_->GetState(ServableId{"bar", 42}));
133  EXPECT_THAT(
134  monitor_->GetVersionStates("foo"),
135  ElementsAre(Pair(43, state_1_and_time), Pair(42, state_0_and_time)));
136  EXPECT_THAT(monitor_->GetVersionStates("bar"),
137  ElementsAre(Pair(7, state_2_and_time)));
138  EXPECT_TRUE(monitor_->GetVersionStates("baz").empty());
139  EXPECT_THAT(monitor_->GetAllServableStates(),
140  UnorderedElementsAre(
141  Pair("foo", ElementsAre(Pair(43, state_1_and_time),
142  Pair(42, state_0_and_time))),
143  Pair("bar", ElementsAre(Pair(7, state_2_and_time)))));
144 
145  EXPECT_THAT(
146  monitor_->GetBoundedLog(),
147  ElementsAre(state_0_and_time, state_1_and_time, state_2_and_time));
148 }
149 
150 TEST_F(ServableStateMonitorTest, UpdatingStates) {
151  CreateMonitor(/*max_count_log_events=*/3);
152 
153  // Initial servables.
154  const ServableState state_0 = {
155  ServableId{"foo", 42}, ServableState::ManagerState::kStart, OkStatus()};
156  env_->AdvanceByMicroseconds(4);
157  const ServableStateAndTime state_0_and_time = {state_0, 4};
158  bus_->Publish(state_0);
159  const ServableState state_1 = {ServableId{"foo", 43},
160  ServableState::ManagerState::kAvailable,
161  errors::Unknown("error")};
162  const ServableStateAndTime state_1_and_time = {state_1, 4};
163  bus_->Publish(state_1);
164  const ServableState state_2 = {ServableId{"bar", 7},
165  ServableState::ManagerState::kUnloading,
166  OkStatus()};
167  const ServableStateAndTime state_2_and_time = {state_2, 4};
168  bus_->Publish(state_2);
169  EXPECT_THAT(monitor_->GetAllServableStates(),
170  UnorderedElementsAre(
171  Pair("foo", ElementsAre(Pair(43, state_1_and_time),
172  Pair(42, state_0_and_time))),
173  Pair("bar", ElementsAre(Pair(7, state_2_and_time)))));
174  EXPECT_THAT(
175  monitor_->GetBoundedLog(),
176  ElementsAre(state_0_and_time, state_1_and_time, state_2_and_time));
177 
178  // Update one of them.
179  const ServableState state_1_updated = {
180  ServableId{"foo", 43}, ServableState::ManagerState::kLoading, OkStatus()};
181  env_->AdvanceByMicroseconds(4);
182  const ServableStateAndTime state_1_updated_and_time = {state_1_updated, 8};
183  bus_->Publish(state_1_updated);
184  ASSERT_TRUE(monitor_->GetState(ServableId{"foo", 42}));
185  EXPECT_EQ(state_0, *monitor_->GetState(ServableId{"foo", 42}));
186  ASSERT_TRUE(monitor_->GetState(ServableId{"foo", 43}));
187  EXPECT_EQ(state_1_updated, *monitor_->GetState(ServableId{"foo", 43}));
188  ASSERT_TRUE(monitor_->GetState(ServableId{"bar", 7}));
189  EXPECT_EQ(state_2, *monitor_->GetState(ServableId{"bar", 7}));
190  EXPECT_THAT(monitor_->GetVersionStates("foo"),
191  ElementsAre(Pair(43, state_1_updated_and_time),
192  Pair(42, state_0_and_time)));
193  EXPECT_THAT(monitor_->GetVersionStates("bar"),
194  ElementsAre(Pair(7, state_2_and_time)));
195  EXPECT_THAT(monitor_->GetAllServableStates(),
196  UnorderedElementsAre(
197  Pair("foo", ElementsAre(Pair(43, state_1_updated_and_time),
198  Pair(42, state_0_and_time))),
199  Pair("bar", ElementsAre(Pair(7, state_2_and_time)))));
200 
201  // The max count for events logged in the bounded log is 3, so the first entry
202  // corresponding to state_0 is removed and an entry is added for
203  // state_1_updated.
204  EXPECT_THAT(monitor_->GetBoundedLog(),
205  ElementsAre(state_1_and_time, state_2_and_time,
206  state_1_updated_and_time));
207 }
208 
209 TEST_F(ServableStateMonitorTest, DisableBoundedLogging) {
210  // The default value for max_count_log_events in options is 0, which disables
211  // logging.
212  CreateMonitor();
213  const ServableState state_0 = {
214  ServableId{"foo", 42}, ServableState::ManagerState::kStart, OkStatus()};
215  env_->AdvanceByMicroseconds(1);
216  const ServableStateAndTime state_0_and_time = {state_0, 1};
217  bus_->Publish(state_0);
218  EXPECT_THAT(monitor_->GetAllServableStates(),
219  UnorderedElementsAre(
220  Pair("foo", ElementsAre(Pair(42, state_0_and_time)))));
221  EXPECT_TRUE(monitor_->GetBoundedLog().empty());
222 }
223 
224 TEST_F(ServableStateMonitorTest, GetLiveServableStates) {
225  CreateMonitor();
226 
227  const ServableState state_0 = {
228  ServableId{"foo", 42}, ServableState::ManagerState::kStart, OkStatus()};
229  env_->AdvanceByMicroseconds(1);
230  const ServableStateAndTime state_0_and_time = {state_0, 1};
231  bus_->Publish(state_0);
232  EXPECT_THAT(monitor_->GetLiveServableStates(),
233  UnorderedElementsAre(
234  Pair("foo", ElementsAre(Pair(42, state_0_and_time)))));
235 
236  const ServableState state_1 = {ServableId{"bar", 7},
237  ServableState::ManagerState::kAvailable,
238  OkStatus()};
239  env_->AdvanceByMicroseconds(1);
240  const ServableStateAndTime state_1_and_time = {state_1, 2};
241  bus_->Publish(state_1);
242  EXPECT_THAT(monitor_->GetLiveServableStates(),
243  UnorderedElementsAre(
244  Pair("foo", ElementsAre(Pair(42, state_0_and_time))),
245  Pair("bar", ElementsAre(Pair(7, state_1_and_time)))));
246 
247  // Servable {foo, 42} moves to state kEnd and is removed from the live states
248  // servables.
249  const ServableState state_0_update = {
250  ServableId{"foo", 42}, ServableState::ManagerState::kEnd, OkStatus()};
251  env_->AdvanceByMicroseconds(1);
252  bus_->Publish(state_0_update);
253  EXPECT_THAT(monitor_->GetLiveServableStates(),
254  UnorderedElementsAre(
255  Pair("bar", ElementsAre(Pair(7, state_1_and_time)))));
256 }
257 
258 TEST_F(ServableStateMonitorTest, GetAvailableServableStates) {
259  CreateMonitor();
260 
261  const ServableState state_0 = {
262  ServableId{"foo", 42}, ServableState::ManagerState::kStart, OkStatus()};
263  env_->AdvanceByMicroseconds(1);
264  const ServableStateAndTime state_0_and_time = {state_0, 1};
265  bus_->Publish(state_0);
266  EXPECT_THAT(monitor_->GetAvailableServableStates(), testing::IsEmpty());
267 
268  env_->AdvanceByMicroseconds(1);
269  std::vector<ServableStateAndTime> servable_state_and_time;
270  for (const auto& servable_id : {ServableId{"bar", 6}, ServableId{"bar", 7}}) {
271  const ServableState state = {
272  servable_id, ServableState::ManagerState::kAvailable, OkStatus()};
273  const ServableStateAndTime state_and_time = {state, 2};
274  servable_state_and_time.push_back({state, 2});
275  bus_->Publish(state);
276  }
277 
278  EXPECT_THAT(monitor_->GetAvailableServableStates(),
279  UnorderedElementsAre("bar"));
280 
281  // Servable {bar, 6} moves to state kUnloading and is removed from available
282  // servable states.
283  const ServableState state_0_update = {ServableId{"bar", 6},
284  ServableState::ManagerState::kUnloading,
285  OkStatus()};
286  env_->AdvanceByMicroseconds(1);
287  bus_->Publish(state_0_update);
288  EXPECT_THAT(monitor_->GetAvailableServableStates(),
289  UnorderedElementsAre("bar"));
290  // Servable {bar, 7} moves to state kEnd and is removed from available
291  // servable states.
292  const ServableState state_1_update = {
293  ServableId{"bar", 7}, ServableState::ManagerState::kEnd, OkStatus()};
294  env_->AdvanceByMicroseconds(1);
295  bus_->Publish(state_1_update);
296  // No available state now.
297  EXPECT_THAT(monitor_->GetAvailableServableStates(), ::testing::IsEmpty());
298 }
299 
300 TEST_F(ServableStateMonitorTest, VersionMapDescendingOrder) {
301  CreateMonitor();
302 
303  const ServableState state_0 = {
304  ServableId{"foo", 42}, ServableState::ManagerState::kStart, OkStatus()};
305  env_->AdvanceByMicroseconds(1);
306  const ServableStateAndTime state_0_and_time = {state_0, 1};
307  bus_->Publish(state_0);
308  EXPECT_THAT(monitor_->GetLiveServableStates(),
309  UnorderedElementsAre(
310  Pair("foo", ElementsAre(Pair(42, state_0_and_time)))));
311 
312  const ServableState state_1 = {ServableId{"foo", 7},
313  ServableState::ManagerState::kAvailable,
314  OkStatus()};
315  env_->AdvanceByMicroseconds(1);
316  const ServableStateAndTime state_1_and_time = {state_1, 2};
317  bus_->Publish(state_1);
318  EXPECT_THAT(monitor_->GetLiveServableStates(),
319  ElementsAre(Pair("foo", ElementsAre(Pair(42, state_0_and_time),
320  Pair(7, state_1_and_time)))));
321 }
322 
323 TEST_F(ServableStateMonitorTest, ForgetUnloadedServableStates) {
324  CreateMonitor();
325 
326  const ServableState state_0 = {ServableId{"foo", 42},
327  ServableState::ManagerState::kAvailable,
328  OkStatus()};
329  env_->AdvanceByMicroseconds(1);
330  const ServableStateAndTime state_0_and_time = {state_0, 1};
331  bus_->Publish(state_0);
332  EXPECT_THAT(monitor_->GetLiveServableStates(),
333  UnorderedElementsAre(
334  Pair("foo", ElementsAre(Pair(42, state_0_and_time)))));
335 
336  const ServableState state_1 = {ServableId{"bar", 1},
337  ServableState::ManagerState::kAvailable,
338  OkStatus()};
339  env_->AdvanceByMicroseconds(1);
340  const ServableStateAndTime state_1_and_time = {state_1, 2};
341  bus_->Publish(state_1);
342  EXPECT_THAT(monitor_->GetLiveServableStates(),
343  UnorderedElementsAre(
344  Pair("foo", ElementsAre(Pair(42, state_0_and_time))),
345  Pair("bar", ElementsAre(Pair(1, state_1_and_time)))));
346 
347  const ServableState state_2 = {ServableId{"foo", 42},
348  ServableState::ManagerState::kUnloading,
349  OkStatus()};
350  env_->AdvanceByMicroseconds(1);
351  const ServableStateAndTime state_2_and_time = {state_2, 3};
352  bus_->Publish(state_2);
353  monitor_->ForgetUnloadedServableStates();
354  // "foo" state should still be recorded since it hasn't reached kEnd.
355  EXPECT_THAT(monitor_->GetAllServableStates(),
356  UnorderedElementsAre(
357  Pair("foo", ElementsAre(Pair(42, state_2_and_time))),
358  Pair("bar", ElementsAre(Pair(1, state_1_and_time)))));
359 
360  const ServableState state_3 = {ServableId{"foo", 42},
361  ServableState::ManagerState::kEnd, OkStatus()};
362  env_->AdvanceByMicroseconds(1);
363  const ServableStateAndTime state_3_and_time = {state_3, 4};
364  bus_->Publish(state_3);
365  EXPECT_THAT(monitor_->GetAllServableStates(),
366  UnorderedElementsAre(
367  Pair("foo", ElementsAre(Pair(42, state_3_and_time))),
368  Pair("bar", ElementsAre(Pair(1, state_1_and_time)))));
369  monitor_->ForgetUnloadedServableStates();
370  EXPECT_THAT(monitor_->GetAllServableStates(),
371  UnorderedElementsAre(
372  Pair("foo", IsEmpty()),
373  Pair("bar", ElementsAre(Pair(1, state_1_and_time)))));
374 }
375 
376 TEST_F(ServableStateMonitorTest, NotifyWhenServablesReachStateZeroServables) {
377  CreateMonitor();
378  const std::vector<ServableRequest> servables = {};
379 
380  using ManagerState = ServableState::ManagerState;
381 
382  Notification notified;
383  monitor_->NotifyWhenServablesReachState(
384  servables, ManagerState::kAvailable,
385  [&](const bool reached,
386  std::map<ServableId, ManagerState> states_reached) {
387  EXPECT_TRUE(reached);
388  EXPECT_THAT(states_reached, IsEmpty());
389  notified.Notify();
390  });
391  notified.WaitForNotification();
392 }
393 
394 TEST_F(ServableStateMonitorTest,
395  NotifyWhenServablesReachStateSpecificAvailable) {
396  CreateMonitor();
397  std::vector<ServableRequest> servables;
398  const ServableId specific_goal_state_id = {"specific_goal_state", 42};
399  servables.push_back(ServableRequest::FromId(specific_goal_state_id));
400 
401  using ManagerState = ServableState::ManagerState;
402  const ServableState specific_goal_state = {
403  specific_goal_state_id, ManagerState::kAvailable, OkStatus()};
404 
405  Notification notified;
406  monitor_->NotifyWhenServablesReachState(
407  servables, ManagerState::kAvailable,
408  [&](const bool reached,
409  std::map<ServableId, ManagerState> states_reached) {
410  EXPECT_TRUE(reached);
411  EXPECT_THAT(states_reached, UnorderedElementsAre(Pair(
412  ServableId{"specific_goal_state", 42},
413  ManagerState::kAvailable)));
414  notified.Notify();
415  });
416  bus_->Publish(specific_goal_state);
417  notified.WaitForNotification();
418 }
419 
420 TEST_F(ServableStateMonitorTest, NotifyWhenServablesReachStateSpecificError) {
421  CreateMonitor();
422  std::vector<ServableRequest> servables;
423  const ServableId specific_error_state_id = {"specific_error_state", 42};
424  servables.push_back(ServableRequest::FromId(specific_error_state_id));
425 
426  using ManagerState = ServableState::ManagerState;
427  const ServableState specific_error_state = {
428  specific_error_state_id, ManagerState::kEnd, errors::Internal("error")};
429 
430  Notification notified;
431  monitor_->NotifyWhenServablesReachState(
432  servables, ManagerState::kAvailable,
433  [&](const bool reached,
434  std::map<ServableId, ManagerState> states_reached) {
435  EXPECT_FALSE(reached);
436  EXPECT_THAT(states_reached,
437  UnorderedElementsAre(
438  Pair(specific_error_state_id, ManagerState::kEnd)));
439  notified.Notify();
440  });
441  bus_->Publish(specific_error_state);
442  notified.WaitForNotification();
443 }
444 
445 TEST_F(ServableStateMonitorTest,
446  NotifyWhenServablesReachStateServableLatestAvailable) {
447  CreateMonitor();
448  std::vector<ServableRequest> servables;
449  servables.push_back(ServableRequest::Latest("servable_stream"));
450  const ServableId servable_stream_available_state_id = {"servable_stream", 42};
451 
452  using ManagerState = ServableState::ManagerState;
453  const ServableState servable_stream_available_state = {
454  servable_stream_available_state_id, ManagerState::kAvailable, OkStatus()};
455 
456  Notification notified;
457  monitor_->NotifyWhenServablesReachState(
458  servables, ManagerState::kAvailable,
459  [&](const bool reached,
460  std::map<ServableId, ManagerState> states_reached) {
461  EXPECT_TRUE(reached);
462  EXPECT_THAT(states_reached, UnorderedElementsAre(
463  Pair(servable_stream_available_state_id,
464  ManagerState::kAvailable)));
465  notified.Notify();
466  });
467  bus_->Publish(servable_stream_available_state);
468  notified.WaitForNotification();
469 }
470 
471 TEST_F(ServableStateMonitorTest, NotifyWhenServablesReachStateLatestError) {
472  CreateMonitor();
473  std::vector<ServableRequest> servables;
474  servables.push_back(ServableRequest::Latest("servable_stream"));
475  const ServableId servable_stream_error_state_id = {"servable_stream", 7};
476 
477  using ManagerState = ServableState::ManagerState;
478  const ServableState servable_stream_error_state = {
479  servable_stream_error_state_id, ManagerState::kEnd,
480  errors::Internal("error")};
481 
482  Notification notified;
483  monitor_->NotifyWhenServablesReachState(
484  servables, ManagerState::kAvailable,
485  [&](const bool reached,
486  std::map<ServableId, ManagerState> states_reached) {
487  EXPECT_FALSE(reached);
488  EXPECT_THAT(states_reached,
489  UnorderedElementsAre(Pair(servable_stream_error_state_id,
490  ManagerState::kEnd)));
491  notified.Notify();
492  });
493  bus_->Publish(servable_stream_error_state);
494  notified.WaitForNotification();
495 }
496 
497 TEST_F(ServableStateMonitorTest,
498  NotifyWhenServablesReachStateFullFunctionality) {
499  using ManagerState = ServableState::ManagerState;
500 
501  CreateMonitor();
502  std::vector<ServableRequest> servables;
503  const ServableId specific_goal_state_id = {"specific_goal_state", 42};
504  servables.push_back(ServableRequest::FromId(specific_goal_state_id));
505  const ServableId specific_error_state_id = {"specific_error_state", 42};
506  servables.push_back(ServableRequest::FromId(specific_error_state_id));
507  servables.push_back(ServableRequest::Latest("servable_stream"));
508  const ServableId servable_stream_id = {"servable_stream", 7};
509 
510  Notification notified;
511  monitor_->NotifyWhenServablesReachState(
512  servables, ManagerState::kAvailable,
513  [&](const bool reached,
514  std::map<ServableId, ManagerState> states_reached) {
515  EXPECT_FALSE(reached);
516  EXPECT_THAT(states_reached,
517  UnorderedElementsAre(
518  Pair(specific_goal_state_id, ManagerState::kAvailable),
519  Pair(specific_error_state_id, ManagerState::kEnd),
520  Pair(servable_stream_id, ManagerState::kAvailable)));
521  notified.Notify();
522  });
523 
524  const ServableState specific_goal_state = {
525  specific_goal_state_id, ManagerState::kAvailable, OkStatus()};
526  const ServableState specific_error_state = {
527  specific_error_state_id, ManagerState::kEnd, errors::Internal("error")};
528  const ServableState servable_stream_state = {
529  servable_stream_id, ManagerState::kAvailable, OkStatus()};
530 
531  bus_->Publish(specific_goal_state);
532  ASSERT_FALSE(notified.HasBeenNotified());
533  bus_->Publish(specific_error_state);
534  ASSERT_FALSE(notified.HasBeenNotified());
535  bus_->Publish(servable_stream_state);
536  notified.WaitForNotification();
537 }
538 
539 TEST_F(ServableStateMonitorTest,
540  NotifyWhenServablesReachStateOnlyNotifiedOnce) {
541  CreateMonitor();
542  std::vector<ServableRequest> servables;
543  const ServableId specific_goal_state_id = {"specific_goal_state", 42};
544  servables.push_back(ServableRequest::FromId(specific_goal_state_id));
545 
546  using ManagerState = ServableState::ManagerState;
547  const ServableState specific_goal_state = {
548  specific_goal_state_id, ManagerState::kAvailable, OkStatus()};
549 
550  Notification notified;
551  monitor_->NotifyWhenServablesReachState(
552  servables, ManagerState::kAvailable,
553  [&](const bool reached,
554  std::map<ServableId, ManagerState> states_reached) {
555  // Will fail if this function is called twice.
556  ASSERT_FALSE(notified.HasBeenNotified());
557  EXPECT_TRUE(reached);
558  EXPECT_THAT(states_reached, UnorderedElementsAre(Pair(
559  ServableId{"specific_goal_state", 42},
560  ManagerState::kAvailable)));
561  notified.Notify();
562  });
563  bus_->Publish(specific_goal_state);
564  notified.WaitForNotification();
565  bus_->Publish(specific_goal_state);
566 }
567 
568 TEST_F(ServableStateMonitorTest,
569  WaitUntilServablesReachStateFullFunctionality) {
570  using ManagerState = ServableState::ManagerState;
571 
572  CreateMonitor();
573  std::vector<ServableRequest> servables;
574  const ServableId specific_goal_state_id = {"specific_goal_state", 42};
575  servables.push_back(ServableRequest::FromId(specific_goal_state_id));
576  const ServableId specific_error_state_id = {"specific_error_state", 42};
577  servables.push_back(ServableRequest::FromId(specific_error_state_id));
578  servables.push_back(ServableRequest::Latest("servable_stream"));
579  const ServableId servable_stream_id = {"servable_stream", 7};
580 
581  const ServableState specific_goal_state = {
582  specific_goal_state_id, ManagerState::kAvailable, OkStatus()};
583  const ServableState specific_error_state = {
584  specific_error_state_id, ManagerState::kEnd, errors::Internal("error")};
585  const ServableState servable_stream_state = {
586  servable_stream_id, ManagerState::kAvailable, OkStatus()};
587 
588  bus_->Publish(specific_goal_state);
589  bus_->Publish(specific_error_state);
590 
591  std::map<ServableId, ManagerState> states_reached;
592  Notification waiting_done;
593  std::unique_ptr<Thread> wait_till_servable_state_reached(
594  Env::Default()->StartThread({}, "WaitUntilServablesReachState", [&]() {
595  EXPECT_FALSE(monitor_->WaitUntilServablesReachState(
596  servables, ManagerState::kAvailable, &states_reached));
597  EXPECT_THAT(states_reached,
598  UnorderedElementsAre(
599  Pair(specific_goal_state_id, ManagerState::kAvailable),
600  Pair(specific_error_state_id, ManagerState::kEnd),
601  Pair(servable_stream_id, ManagerState::kAvailable)));
602  waiting_done.Notify();
603  }));
604  // We publish till waiting is finished, otherwise we could publish before we
605  // could start waiting.
606  while (!waiting_done.HasBeenNotified()) {
607  bus_->Publish(servable_stream_state);
608  }
609 }
610 
611 } // namespace
612 } // namespace serving
613 } // namespace tensorflow
static std::shared_ptr< EventBus > CreateEventBus(const Options &options={})
Definition: event_bus.h:191