TensorFlow Serving C++ API Documentation
aspired_versions_manager_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/aspired_versions_manager.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <tuple>
24 #include <utility>
25 #include <vector>
26 
27 #include <gmock/gmock.h>
28 #include <gtest/gtest.h>
29 #include "absl/types/optional.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/notification.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/protobuf/error_codes.pb.h"
35 #include "tensorflow_serving/core/aspired_version_policy.h"
36 #include "tensorflow_serving/core/availability_preserving_policy.h"
37 #include "tensorflow_serving/core/servable_state_monitor.h"
38 #include "tensorflow_serving/core/test_util/availability_test_util.h"
39 #include "tensorflow_serving/core/test_util/fake_loader.h"
40 #include "tensorflow_serving/core/test_util/manager_test_util.h"
41 #include "tensorflow_serving/core/test_util/mock_loader.h"
42 #include "tensorflow_serving/util/any_ptr.h"
43 #include "tensorflow_serving/util/event_bus.h"
44 
45 namespace tensorflow {
46 namespace serving {
47 namespace {
48 
49 using test_util::FakeLoader;
50 using test_util::WaitUntilServableManagerStateIsOneOf;
51 using ::testing::_;
52 using ::testing::Invoke;
53 using ::testing::InvokeWithoutArgs;
54 using ::testing::NiceMock;
55 using ::testing::Return;
56 using ::testing::UnorderedElementsAre;
57 using ::testing::UnorderedElementsAreArray;
58 
59 constexpr char kServableName[] = "kServableName";
60 constexpr char kServableName2[] = "kServableName2";
61 constexpr int kNumVersionsPerServable = 2;
62 constexpr int kNumTotalVersions = 4;
63 
64 // Creates an aspired-versions entry with 'id' and a FakeLoader whose servable
65 // is id.version.
66 ServableData<std::unique_ptr<Loader>> CreateAspiredVersion(
67  const ServableId& id) {
68  std::unique_ptr<Loader> loader(new FakeLoader(id.version));
69  return CreateServableData(id, std::move(loader));
70 }
71 
// Test parameter: the number of load and unload threads given to the manager
// under test. Zero means an in-line executor is used instead of a thread pool.
// Both fields default to zero so that a default-constructed value is well
// defined; aggregate initialization such as ThreadPoolSizes{2, 0} still works.
struct ThreadPoolSizes {
  uint64_t num_load_threads = 0;
  uint64_t num_unload_threads = 0;
};
78 class AspiredVersionsManagerTest
79  : public ::testing::TestWithParam<std::tuple<ThreadPoolSizes, bool>> {
80  protected:
81  AspiredVersionsManagerTest()
82  : servable_event_bus_(EventBus<ServableState>::CreateEventBus()),
83  servable_state_monitor_(servable_event_bus_.get()),
84  thread_pool_sizes_(std::get<0>(GetParam())),
85  enable_reload_servables_with_error_(std::get<1>(GetParam())) {
86  AspiredVersionsManager::Options manager_options;
87  manager_options.num_load_threads = thread_pool_sizes_.num_load_threads;
88  manager_options.num_unload_threads = thread_pool_sizes_.num_unload_threads;
89  // The state manager thread won't be run automatically.
90  manager_options.manage_state_interval_micros = -1;
91  manager_options.env = Env::Default();
92  manager_options.aspired_version_policy.reset(
93  new AvailabilityPreservingPolicy());
94  manager_options.servable_event_bus = servable_event_bus_.get();
95  max_num_load_retries_ = 1;
96  manager_options.max_num_load_retries = max_num_load_retries_;
97  manager_options.load_retry_interval_micros = 0;
98  manager_options.enable_reload_servables_with_error =
99  enable_reload_servables_with_error_;
100  TF_CHECK_OK(
101  AspiredVersionsManager::Create(std::move(manager_options), &manager_));
102  }
103 
104  // Creates an aspired-versions entry with 'id' and an error (and no loader).
105  ServableData<std::unique_ptr<Loader>> CreateErroneousAspiredVersion(
106  const ServableId& id) {
107  return ServableData<std::unique_ptr<Loader>>(id, errors::Unknown("error"));
108  }
109 
110  void SetUp() override {
111  // We setUp the manager_ with two different servable streams, each with two
112  // aspired versions 0 and 1.
113  std::set<ServableId> servables;
114  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
115  for (int i = 0; i < kNumVersionsPerServable; ++i) {
116  const ServableId id = {kServableName, i};
117  aspired_versions.push_back(CreateAspiredVersion(id));
118  servables.insert(id);
119  }
120  manager_->GetAspiredVersionsCallback()(kServableName,
121  std::move(aspired_versions));
122  HandlePendingAspiredVersionsRequests();
123 
124  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions2;
125  for (int i = 0; i < kNumVersionsPerServable; ++i) {
126  const ServableId id = {kServableName2, i};
127  aspired_versions2.push_back(CreateAspiredVersion(id));
128  servables.insert(id);
129  }
130  manager_->GetAspiredVersionsCallback()(kServableName2,
131  std::move(aspired_versions2));
132  HandlePendingAspiredVersionsRequests();
133 
134  for (int i = 0; i < kNumTotalVersions; ++i) {
135  // Each time the state manager thread is run, we should load a servable
136  // version.
137  InvokePolicyAndExecuteAction();
138  }
139  for (const ServableId& servable : servables) {
140  WaitUntilServableManagerStateIsOneOf(
141  servable_state_monitor_, servable,
142  {ServableState::ManagerState::kAvailable});
143  }
144  }
145 
146  void FlushServables() {
147  test_util::AspiredVersionsManagerTestAccess(manager_.get())
148  .FlushServables();
149  }
150 
151  void HandlePendingAspiredVersionsRequests() {
152  test_util::AspiredVersionsManagerTestAccess(manager_.get())
153  .HandlePendingAspiredVersionsRequests();
154  }
155 
156  void InvokePolicyAndExecuteAction() {
157  test_util::AspiredVersionsManagerTestAccess(manager_.get())
158  .InvokePolicyAndExecuteAction();
159  }
160 
161  std::shared_ptr<EventBus<ServableState>> servable_event_bus_;
162  ServableStateMonitor servable_state_monitor_;
163  ThreadPoolSizes thread_pool_sizes_;
164  uint32 max_num_load_retries_;
165  bool enable_reload_servables_with_error_;
166  std::unique_ptr<AspiredVersionsManager> manager_;
167 };
168 
169 INSTANTIATE_TEST_CASE_P(
170  WithOrWithoutThreadPools, AspiredVersionsManagerTest,
171  ::testing::Values(
172  // without load or unload threadpools
173  std::make_tuple(ThreadPoolSizes{0, 0}, false),
174  // with just a load threadpool
175  std::make_tuple(ThreadPoolSizes{2, 0}, false),
176  // with just an unload threadpool
177  std::make_tuple(ThreadPoolSizes{0, 2}, false),
178  // with load and unload threadpools
179  std::make_tuple(ThreadPoolSizes{4, 4}, false),
180  // without load or unload threadpools and retries of failed loads
181  std::make_tuple(ThreadPoolSizes{0, 0}, true),
182  // with load and unload threadpools and retries of failed loads
183  std::make_tuple(ThreadPoolSizes{4, 4}, true)));
184 
185 TEST_P(AspiredVersionsManagerTest, ServableHandleNotFoundMissingLoaderName) {
186  ServableHandle<int64_t> handle;
187  const Status status = manager_->GetServableHandle(
188  ServableRequest::Latest(strings::StrCat(kServableName, "missing")),
189  &handle);
190  ASSERT_FALSE(status.ok()) << status;
191  EXPECT_EQ(error::NOT_FOUND, status.code());
192 }
193 
194 TEST_P(AspiredVersionsManagerTest, ServableHandleNotFoundMissingVersion) {
195  // This version is missing.
196  const int64_t missing_version = 100;
197  ServableHandle<int64_t> handle;
198  const Status status = manager_->GetServableHandle(
199  ServableRequest::Specific(kServableName, missing_version), &handle);
200  ASSERT_FALSE(status.ok()) << status;
201  EXPECT_EQ(error::NOT_FOUND, status.code());
202 }
203 
204 TEST_P(AspiredVersionsManagerTest, ServableHandleInvalidArgument) {
205  // The servable is supposed to be an int type and we ask for a float type,
206  // thus causing an invalid argument error.
207  ServableHandle<float> handle;
208  const Status status = manager_->GetServableHandle(
209  ServableRequest::Latest(kServableName), &handle);
210  ASSERT_FALSE(status.ok()) << status;
211  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
212 }
213 
214 TEST_P(AspiredVersionsManagerTest, ServableHandleLatest) {
215  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
216  const ServableId id = {kServableName, kNumVersionsPerServable + 1};
217  aspired_versions.push_back(CreateAspiredVersion(id));
218  manager_->GetAspiredVersionsCallback()(kServableName,
219  std::move(aspired_versions));
220  HandlePendingAspiredVersionsRequests();
221  // Unload version 0 and load the new aspired version. Version 1 may or may not
222  // be unloaded (depending on whether load/unload thread pools are used).
223  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
224  InvokePolicyAndExecuteAction();
225  }
226  WaitUntilServableManagerStateIsOneOf(
227  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
228 
229  ServableHandle<int64_t> handle;
230  const Status status = manager_->GetServableHandle(
231  ServableRequest::Latest(kServableName), &handle);
232  TF_ASSERT_OK(status);
233  EXPECT_EQ(kNumVersionsPerServable + 1, *handle);
234 }
235 
236 // Test the case where the latest version of a servable available is 0.
237 TEST_P(AspiredVersionsManagerTest, ServableHandleLatestVersionIsZero) {
238  const char kServableName3[] = "kServableName3";
239 
240  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
241  const ServableId id = {kServableName3, 0};
242  aspired_versions.push_back(CreateAspiredVersion(id));
243  manager_->GetAspiredVersionsCallback()(kServableName3,
244  std::move(aspired_versions));
245  HandlePendingAspiredVersionsRequests();
246 
247  InvokePolicyAndExecuteAction();
248  WaitUntilServableManagerStateIsOneOf(
249  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
250 
251  ServableHandle<int64_t> handle;
252  const Status status = manager_->GetServableHandle(
253  ServableRequest::Latest(kServableName3), &handle);
254  TF_ASSERT_OK(status);
255  EXPECT_EQ(0, *handle);
256  EXPECT_EQ(id, handle.id());
257 }
258 
259 TEST_P(AspiredVersionsManagerTest, ReloadAspiredError) {
260  const char kServableName[] = "kAspiredError";
261  auto callback_fn = manager_->GetAspiredVersionsCallback();
262  // First, load a working Servable under version 1.
263  {
264  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
265  const ServableId id = {kServableName, 1};
266  aspired_versions.push_back(CreateAspiredVersion(id));
267  callback_fn(kServableName, std::move(aspired_versions));
268  HandlePendingAspiredVersionsRequests();
269  InvokePolicyAndExecuteAction();
270  WaitUntilServableManagerStateIsOneOf(
271  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
272  ServableHandle<int64_t> handle;
273  const Status status = manager_->GetServableHandle(
274  ServableRequest::Latest(kServableName), &handle);
275  TF_ASSERT_OK(status);
276  EXPECT_EQ(1, *handle);
277  EXPECT_EQ(id, handle.id());
278  }
279  // Having a failing servable load for version 2.
280  {
281  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
282  const ServableId id = {kServableName, 2};
283  aspired_versions.push_back(CreateErroneousAspiredVersion(id));
284  callback_fn(kServableName, std::move(aspired_versions));
285  HandlePendingAspiredVersionsRequests();
286  InvokePolicyAndExecuteAction();
287  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
288  {ServableState::ManagerState::kEnd});
289  ServableHandle<int64_t> handle;
290  Status status = manager_->GetServableHandle(
291  ServableRequest::Specific(kServableName, 2), &handle);
292  EXPECT_FALSE(status.ok()) << status;
293  }
294  // Attempt to reload servable for version 2.
295  {
296  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
297  const ServableId id = {kServableName, 2};
298  aspired_versions.push_back(CreateAspiredVersion(id));
299  callback_fn(kServableName, std::move(aspired_versions));
300  HandlePendingAspiredVersionsRequests();
301  InvokePolicyAndExecuteAction();
302  if (enable_reload_servables_with_error_) {
303  WaitUntilServableManagerStateIsOneOf(
304  servable_state_monitor_, id,
305  {ServableState::ManagerState::kAvailable});
306  ServableHandle<int64_t> handle;
307  Status status = manager_->GetServableHandle(
308  ServableRequest::Specific(kServableName, 2), &handle);
309  TF_ASSERT_OK(status) << status;
310  } else {
311  // Sleep for 1ms. There's nothing to wait on as the state will not change.
312  Env::Default()->SleepForMicroseconds(1000 /* 1 ms */);
313  ServableHandle<int64_t> handle;
314  Status status = manager_->GetServableHandle(
315  ServableRequest::Specific(kServableName, 2), &handle);
316  EXPECT_FALSE(status.ok()) << status;
317  }
318  }
319 }
320 
321 TEST_P(AspiredVersionsManagerTest, ServableHandleSpecificVersion) {
322  ServableHandle<int64_t> handle;
323  const ServableId id = {kServableName2, 0};
324  const Status status =
325  manager_->GetServableHandle(ServableRequest::FromId(id), &handle);
326  TF_ASSERT_OK(status);
327  EXPECT_EQ(0, *handle);
328  EXPECT_EQ(id, handle.id());
329 }
330 
331 TEST_P(AspiredVersionsManagerTest, ListAvailableServableIds) {
332  const std::vector<ServableId> expected_before = {{kServableName, 0},
333  {kServableName, 1},
334  {kServableName2, 0},
335  {kServableName2, 1}};
336  EXPECT_THAT(manager_->ListAvailableServableIds(),
337  UnorderedElementsAreArray(expected_before));
338 
339  // Set stream kServableName to have servables 7.
340  // This causes 0 & 1 to be unloaded and 7 to be loaded, but 7 errors on load,
341  // so never moves to a loaded state.
342  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
343  const ServableId id = {kServableName, 7};
344  std::unique_ptr<Loader> loader(
345  new FakeLoader(7, errors::Internal("An error.")));
346  aspired_versions.push_back({id, std::move(loader)});
347  manager_->GetAspiredVersionsCallback()(kServableName,
348  std::move(aspired_versions));
349  HandlePendingAspiredVersionsRequests();
350  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
351  InvokePolicyAndExecuteAction();
352  }
353  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
354  {ServableState::ManagerState::kEnd});
355 
356  manager_->GetAspiredVersionsCallback()(kServableName, {});
357  HandlePendingAspiredVersionsRequests();
358  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
359  InvokePolicyAndExecuteAction();
360  }
361  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
362  {kServableName, 0},
363  {ServableState::ManagerState::kEnd});
364  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
365  {kServableName, 1},
366  {ServableState::ManagerState::kEnd});
367 
368  const std::vector<ServableId> expected_after = {{kServableName2, 0},
369  {kServableName2, 1}};
370  EXPECT_THAT(manager_->ListAvailableServableIds(),
371  UnorderedElementsAreArray(expected_after));
372 }
373 
374 TEST_P(AspiredVersionsManagerTest, GetAvailableServableHandles) {
375  // Scoped to destruct handles at the end of it.
376  {
377  const std::map<ServableId, ServableHandle<int64_t>> handles_before =
378  manager_->GetAvailableServableHandles<int64_t>();
379  ASSERT_EQ(kNumVersionsPerServable * 2, handles_before.size());
380 
381  const std::vector<ServableId> expected_ids_before = {{kServableName, 0},
382  {kServableName, 1},
383  {kServableName2, 0},
384  {kServableName2, 1}};
385  for (const ServableId& expected_id : expected_ids_before) {
386  const auto found_it = handles_before.find(expected_id);
387  ASSERT_TRUE(found_it != handles_before.end());
388  EXPECT_EQ(expected_id.version, *found_it->second);
389  }
390  }
391 
392  // Set stream kServableName to have servables 7.
393  // This causes 0 & 1 to be unloaded and 7 to be loaded, but 7 errors on load,
394  // so never moves to a loaded state.
395  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
396  const ServableId id = {kServableName, 7};
397  std::unique_ptr<Loader> loader(
398  new FakeLoader(7, errors::Internal("An error.")));
399  aspired_versions.push_back({id, std::move(loader)});
400  manager_->GetAspiredVersionsCallback()(kServableName,
401  std::move(aspired_versions));
402  HandlePendingAspiredVersionsRequests();
403  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
404  InvokePolicyAndExecuteAction();
405  }
406  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
407  {ServableState::ManagerState::kEnd});
408 
409  manager_->GetAspiredVersionsCallback()(kServableName, {});
410  HandlePendingAspiredVersionsRequests();
411  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
412  InvokePolicyAndExecuteAction();
413  }
414  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
415  {kServableName, 0},
416  {ServableState::ManagerState::kEnd});
417  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
418  {kServableName, 1},
419  {ServableState::ManagerState::kEnd});
420  {
421  const std::map<ServableId, ServableHandle<int64_t>> handles_after =
422  manager_->GetAvailableServableHandles<int64_t>();
423  ASSERT_EQ(kNumVersionsPerServable, handles_after.size());
424 
425  const std::vector<ServableId> expected_ids_after = {{kServableName2, 0},
426  {kServableName2, 1}};
427  for (const ServableId& expected_id : expected_ids_after) {
428  const auto found_it = handles_after.find(expected_id);
429  ASSERT_TRUE(found_it != handles_after.end());
430  EXPECT_EQ(expected_id.version, *found_it->second);
431  }
432  }
433 }
434 
435 TEST_P(AspiredVersionsManagerTest, GetAvailableServableHandlesWrongType) {
436  const std::map<ServableId, ServableHandle<int>> wrong_type_handles =
437  manager_->GetAvailableServableHandles<int>();
438  EXPECT_EQ(0, wrong_type_handles.size());
439 }
440 
441 TEST_P(AspiredVersionsManagerTest, AspiredRemovedFull) {
442  // Scoped so that the handle is destructed at the end, and the harness is
443  // destructed when we run the manager looping thread.
444  {
445  ServableHandle<int64_t> handle;
446  const Status status = manager_->GetServableHandle(
447  ServableRequest::Latest(kServableName), &handle);
448  TF_ASSERT_OK(status);
449  EXPECT_EQ(1, *handle);
450  }
451 
452  manager_->GetAspiredVersionsCallback()(kServableName, {});
453  HandlePendingAspiredVersionsRequests();
454 
455  const int num_fake_loaders_before = FakeLoader::num_fake_loaders();
456  for (int i = 0; i < kNumVersionsPerServable; ++i) {
457  InvokePolicyAndExecuteAction();
458  }
459  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
460  {kServableName, 0},
461  {ServableState::ManagerState::kEnd});
462  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
463  {kServableName, 1},
464  {ServableState::ManagerState::kEnd});
465  FlushServables();
466  const int num_fake_loaders_after = FakeLoader::num_fake_loaders();
467  EXPECT_EQ(kNumVersionsPerServable,
468  num_fake_loaders_before - num_fake_loaders_after);
469 
470  ServableHandle<int64_t> missing_handle;
471  const Status missing_status = manager_->GetServableHandle(
472  ServableRequest::Latest(kServableName), &missing_handle);
473  ASSERT_FALSE(missing_status.ok());
474  EXPECT_EQ(error::NOT_FOUND, missing_status.code());
475 }
476 
477 TEST_P(AspiredVersionsManagerTest, AspiredRemovedPartial) {
478  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
479  aspired_versions.push_back(CreateAspiredVersion({kServableName, 0}));
480  manager_->GetAspiredVersionsCallback()(kServableName,
481  std::move(aspired_versions));
482  HandlePendingAspiredVersionsRequests();
483 
484  InvokePolicyAndExecuteAction();
485  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
486  {kServableName, 1},
487  {ServableState::ManagerState::kEnd});
488 
489  // Version 0 should remain available in the manager.
490  ServableHandle<int64_t> v0_handle;
491  const Status v0_status = manager_->GetServableHandle(
492  ServableRequest::Specific(kServableName, 0), &v0_handle);
493  TF_ASSERT_OK(v0_status);
494  EXPECT_EQ(0, *v0_handle);
495 
496  // Version 1 should no longer be available.
497  ServableHandle<int64_t> v1_handle;
498  const Status v1_status = manager_->GetServableHandle(
499  ServableRequest::Specific(kServableName, 1), &v1_handle);
500  ASSERT_FALSE(v1_status.ok());
501  EXPECT_EQ(error::NOT_FOUND, v1_status.code());
502 }
503 
504 TEST_P(AspiredVersionsManagerTest, RevertToSmallerVersionNumber) {
505  // Initially, versions 0 and 1 of kServableName are loaded.
506  std::set<int64_t> initial_versions;
507  for (const ServableId& id : manager_->ListAvailableServableIds()) {
508  if (id.name == kServableName) {
509  initial_versions.insert(id.version);
510  }
511  }
512  ASSERT_THAT(initial_versions, UnorderedElementsAre(0, 1));
513 
514  // Unload version 0, s.t. only version 1 is loaded.
515  std::vector<ServableData<std::unique_ptr<Loader>>> initial_aspired_versions;
516  initial_aspired_versions.push_back(CreateAspiredVersion({kServableName, 1}));
517  manager_->GetAspiredVersionsCallback()(kServableName,
518  std::move(initial_aspired_versions));
519  HandlePendingAspiredVersionsRequests();
520  InvokePolicyAndExecuteAction();
521  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
522  {kServableName, 0},
523  {ServableState::ManagerState::kEnd});
524  FlushServables();
525 
526  // Now, switch to version 0 (dropping version 1).
527  std::vector<ServableData<std::unique_ptr<Loader>>> new_aspired_versions;
528  new_aspired_versions.push_back(CreateAspiredVersion({kServableName, 0}));
529  manager_->GetAspiredVersionsCallback()(kServableName,
530  std::move(new_aspired_versions));
531  HandlePendingAspiredVersionsRequests();
532  Notification done_transitioning;
533  std::unique_ptr<Thread> transition_servables(
534  Env::Default()->StartThread({}, "TransitionServables", [&]() {
535  while (!done_transitioning.HasBeenNotified()) {
536  InvokePolicyAndExecuteAction();
537  Env::Default()->SleepForMicroseconds(1000 /* 1 ms */);
538  }
539  }));
540  WaitUntilServableManagerStateIsOneOf(
541  servable_state_monitor_, {kServableName, 0},
542  {ServableState::ManagerState::kAvailable});
543  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
544  {kServableName, 1},
545  {ServableState::ManagerState::kEnd});
546  done_transitioning.Notify();
547 
548  // Version 0 should be available.
549  ServableHandle<int64_t> v0_handle;
550  const Status v0_status = manager_->GetServableHandle(
551  ServableRequest::Specific(kServableName, 0), &v0_handle);
552  TF_ASSERT_OK(v0_status);
553  EXPECT_EQ(0, *v0_handle);
554 
555  // Version 1 should not be available.
556  ServableHandle<int64_t> v1_handle;
557  const Status v1_status = manager_->GetServableHandle(
558  ServableRequest::Specific(kServableName, 1), &v1_handle);
559  ASSERT_FALSE(v1_status.ok());
560  EXPECT_EQ(error::NOT_FOUND, v1_status.code());
561 }
562 
563 TEST_P(AspiredVersionsManagerTest, AspiredAndManageStateLoad) {
564  const ServableId id = {kServableName, 2};
565  ServableHandle<int64_t> not_found_handle;
566  const Status not_found_status = manager_->GetServableHandle(
567  ServableRequest::FromId(id), &not_found_handle);
568  ASSERT_FALSE(not_found_status.ok()) << not_found_status;
569  EXPECT_EQ(error::NOT_FOUND, not_found_status.code());
570 
571  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
572  aspired_versions.push_back(CreateAspiredVersion(id));
573  manager_->GetAspiredVersionsCallback()(kServableName,
574  std::move(aspired_versions));
575  HandlePendingAspiredVersionsRequests();
576 
577  ServableHandle<int64_t> not_ready_handle;
578  const Status not_ready_status = manager_->GetServableHandle(
579  ServableRequest::FromId(id), &not_ready_handle);
580  ASSERT_FALSE(not_ready_status.ok()) << not_ready_status;
581  EXPECT_EQ(error::NOT_FOUND, not_ready_status.code());
582 
583  // Unload version 0 and load the new aspired version. Version 1 may or may not
584  // be unloaded (depending on whether load/unload thread pools are used).
585  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
586  InvokePolicyAndExecuteAction();
587  }
588  WaitUntilServableManagerStateIsOneOf(
589  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
590 
591  ServableHandle<int64_t> handle;
592  const Status status =
593  manager_->GetServableHandle(ServableRequest::FromId(id), &handle);
594  TF_ASSERT_OK(status);
595  EXPECT_EQ(2, *handle);
596 }
597 
598 TEST_P(AspiredVersionsManagerTest, AspiredAndManageStateUnload) {
599  {
600  ServableHandle<int64_t> handle;
601  const Status status = manager_->GetServableHandle(
602  ServableRequest::Specific(kServableName, 0), &handle);
603  TF_ASSERT_OK(status);
604  EXPECT_EQ(0, *handle);
605  }
606 
607  manager_->GetAspiredVersionsCallback()(kServableName, {});
608  HandlePendingAspiredVersionsRequests();
609 
610  for (int i = 0; i < kNumVersionsPerServable; ++i) {
611  InvokePolicyAndExecuteAction();
612  }
613  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
614  {kServableName, 0},
615  {ServableState::ManagerState::kEnd});
616  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
617  {kServableName, 1},
618  {ServableState::ManagerState::kEnd});
619 
620  ServableHandle<int64_t> not_found_handle;
621  const Status not_found_status = manager_->GetServableHandle(
622  ServableRequest::Specific(kServableName, 0), &not_found_handle);
623  ASSERT_FALSE(not_found_status.ok()) << not_found_status;
624  EXPECT_EQ(error::NOT_FOUND, not_found_status.code());
625 }
626 
627 // The manager prefers unloading over loading when deciding between different
628 // servable actions. This behaviour is tested here.
629 TEST_P(AspiredVersionsManagerTest, ManagerPrefersUnloadOverLoad) {
630  ServableHandle<int64_t> not_found_2_handle;
631  Status not_found_2_status = manager_->GetServableHandle(
632  ServableRequest::Specific(kServableName2, 2), &not_found_2_handle);
633  ASSERT_FALSE(not_found_2_status.ok()) << not_found_2_status;
634  EXPECT_EQ(error::NOT_FOUND, not_found_2_status.code());
635 
636  {
637  ServableHandle<int64_t> found_0_handle;
638  const Status found_0_status = manager_->GetServableHandle(
639  ServableRequest::Specific(kServableName, 0), &found_0_handle);
640  TF_ASSERT_OK(found_0_status);
641  EXPECT_EQ(0, *found_0_handle);
642  }
643 
644  // We want to unload version 0 of the first servable stream and load version 2
645  // of the second stream.
646  struct {
647  StringPiece name;
648  int start;
649  int end;
650  } servable_aspired_list[2] = {{kServableName, 1, 1}, {kServableName2, 0, 2}};
651  for (const auto& servable_aspired : servable_aspired_list) {
652  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
653  for (int i = servable_aspired.start; i <= servable_aspired.end; ++i) {
654  const ServableId id = {string(servable_aspired.name), i};
655  aspired_versions.push_back(CreateAspiredVersion(id));
656  }
657  manager_->GetAspiredVersionsCallback()(servable_aspired.name,
658  std::move(aspired_versions));
659  HandlePendingAspiredVersionsRequests();
660  }
661 
662  // The manager prefers to unload a servable before loading a servable, so it
663  // should prefer to unload version 0 of the first servable stream.
664  InvokePolicyAndExecuteAction();
665  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
666  {kServableName, 0},
667  {ServableState::ManagerState::kEnd});
668 
669  ServableHandle<int64_t> not_found_0_handle;
670  const Status not_found_0_status = manager_->GetServableHandle(
671  ServableRequest::Specific(kServableName, 0), &not_found_0_handle);
672  ASSERT_FALSE(not_found_0_status.ok()) << not_found_0_status;
673  EXPECT_EQ(error::NOT_FOUND, not_found_2_status.code());
674 
675  not_found_2_status = manager_->GetServableHandle(
676  ServableRequest::Specific(kServableName2, 2), &not_found_2_handle);
677  ASSERT_FALSE(not_found_2_status.ok()) << not_found_2_status;
678  EXPECT_EQ(error::NOT_FOUND, not_found_2_status.code());
679 
680  // Now it should load version 2 of the second servable stream.
681  InvokePolicyAndExecuteAction();
682  WaitUntilServableManagerStateIsOneOf(
683  servable_state_monitor_, {kServableName2, 2},
684  {ServableState::ManagerState::kAvailable});
685 
686  ServableHandle<int64_t> found_2_handle;
687  const Status found_2_status = manager_->GetServableHandle(
688  ServableRequest::Specific(kServableName2, 2), &found_2_handle);
689  TF_ASSERT_OK(found_2_status);
690  EXPECT_EQ(2, *found_2_handle);
691 }
692 
693 TEST_P(AspiredVersionsManagerTest, CustomSortActions) {
694  test_util::AspiredVersionsManagerTestAccess(manager_.get())
695  .SetCustomSortActions(
696  [](const AspiredVersionPolicy::ServableAction& lhs,
697  const AspiredVersionPolicy::ServableAction& rhs) -> bool {
698  // Prefer kServableName2 over anything else; note the impl needs to
699  // be a valid strict-weak ordering
700  bool lhs_is_servable_2 = lhs.id.name == kServableName2;
701  bool rhs_is_servable_2 = rhs.id.name == kServableName2;
702  if (lhs_is_servable_2 != rhs_is_servable_2) {
703  return lhs_is_servable_2;
704  }
705  return false;
706  });
707 
708  {
709  ServableHandle<int64_t> not_found_2_handle;
710  Status not_found_2_status = manager_->GetServableHandle(
711  ServableRequest::Specific(kServableName2, 2), &not_found_2_handle);
712  ASSERT_FALSE(not_found_2_status.ok()) << not_found_2_status;
713  EXPECT_EQ(error::NOT_FOUND, not_found_2_status.code());
714  }
715 
716  {
717  ServableHandle<int64_t> found_0_handle;
718  Status found_0_status = manager_->GetServableHandle(
719  ServableRequest::Specific(kServableName, 0), &found_0_handle);
720  TF_ASSERT_OK(found_0_status);
721  EXPECT_EQ(0, *found_0_handle);
722  }
723 
724  // We want to unload version 0 of the first servable stream and load version 2
725  // of the second stream.
726  struct {
727  StringPiece name;
728  int start;
729  int end;
730  } servable_aspired_list[2] = {{kServableName, 1, 1}, {kServableName2, 0, 2}};
731  for (const auto& servable_aspired : servable_aspired_list) {
732  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
733  for (int i = servable_aspired.start; i <= servable_aspired.end; ++i) {
734  const ServableId id = {string(servable_aspired.name), i};
735  aspired_versions.push_back(CreateAspiredVersion(id));
736  }
737  manager_->GetAspiredVersionsCallback()(servable_aspired.name,
738  std::move(aspired_versions));
739  HandlePendingAspiredVersionsRequests();
740  }
741 
742  // By default, the manager prefers to unload a servable before loading a
743  // servable, see ManagerPrefersUnloadOverLoad test case above; but here our
744  // custom sort order prefers kServableName2
745  InvokePolicyAndExecuteAction();
746  WaitUntilServableManagerStateIsOneOf(
747  servable_state_monitor_, {kServableName2, 2},
748  {ServableState::ManagerState::kAvailable});
749 
750  {
751  ServableHandle<int64_t> found_0_handle;
752  Status found_0_status = manager_->GetServableHandle(
753  ServableRequest::Specific(kServableName, 0), &found_0_handle);
754  TF_ASSERT_OK(found_0_status);
755  EXPECT_EQ(0, *found_0_handle);
756  }
757 
758  {
759  ServableHandle<int64_t> found_2_handle;
760  const Status found_2_status = manager_->GetServableHandle(
761  ServableRequest::Specific(kServableName2, 2), &found_2_handle);
762  TF_ASSERT_OK(found_2_status);
763  EXPECT_EQ(2, *found_2_handle);
764  }
765 
766  // Now it would unload the first
767  InvokePolicyAndExecuteAction();
768  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
769  {kServableName, 0},
770  {ServableState::ManagerState::kEnd});
771 
772  {
773  ServableHandle<int64_t> not_found_0_handle;
774  const Status not_found_0_status = manager_->GetServableHandle(
775  ServableRequest::Specific(kServableName, 0), &not_found_0_handle);
776  ASSERT_FALSE(not_found_0_status.ok()) << not_found_0_status;
777  EXPECT_EQ(error::NOT_FOUND, not_found_0_status.code());
778  }
779 }
780 
781 // Test to ensure the manager doesn't try to load or serve an incoming erroneous
782 // aspired-version entry.
783 TEST_P(AspiredVersionsManagerTest, ErroneousAspiredVersion) {
784  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
785  aspired_versions.push_back(CreateErroneousAspiredVersion({kServableName, 3}));
786  manager_->GetAspiredVersionsCallback()(kServableName,
787  std::move(aspired_versions));
788  HandlePendingAspiredVersionsRequests();
789 
790  ServableHandle<int64_t> handle;
791  Status status = manager_->GetServableHandle(
792  ServableRequest::Specific(kServableName, 3), &handle);
793  EXPECT_FALSE(status.ok()) << status;
794 
795  InvokePolicyAndExecuteAction();
796 
797  status = manager_->GetServableHandle(
798  ServableRequest::Specific(kServableName, 3), &handle);
799  EXPECT_FALSE(status.ok()) << status;
800 }
801 
// Test to ensure that the deletion of a loader/servable occurs in a manager
// thread, and not a request thread.
//
// The main thread acts as the request thread: it holds a handle to the latest
// servable while a separate "UnloadServable" thread drives the manager through
// unloading every version. The handle is released on the main thread only
// after that thread has started, and the test then checks that the loader was
// not deleted on the main thread.
TEST_P(AspiredVersionsManagerTest, DestructOnNonServingThread) {
  // Acquire and hold a handle to the latest version; the test expects the
  // fixture to serve version 1 as the latest.
  std::unique_ptr<ServableHandle<int64_t>> latest_handle(
      new ServableHandle<int64_t>());
  const Status status = manager_->GetServableHandle(
      ServableRequest::Latest(kServableName), latest_handle.get());
  TF_ASSERT_OK(status);
  EXPECT_EQ(1, **latest_handle);

  // Un-aspire all versions of kServableName (empty aspired list).
  manager_->GetAspiredVersionsCallback()(kServableName, {});
  HandlePendingAspiredVersionsRequests();

  Notification done_unload_servable;
  std::unique_ptr<Thread> unload_servable(
      Env::Default()->StartThread({}, "UnloadServable", [&]() {
        // Unload the servable.
        for (int i = 0; i < kNumVersionsPerServable; ++i) {
          InvokePolicyAndExecuteAction();
        }
        WaitUntilServableManagerStateIsOneOf(
            servable_state_monitor_, {kServableName, 0},
            {ServableState::ManagerState::kEnd});
        FlushServables();
        // The servable has been deleted in this thread if there is no
        // thread-pool for unload.
        if (thread_pool_sizes_.num_unload_threads == 0) {
          EXPECT_TRUE(FakeLoader::was_deleted_in_this_thread());
        }
        done_unload_servable.Notify();
      }));

  // This will unblock the UnloadServable. (Presumably the unload of the
  // version pinned by this handle cannot complete while the handle is held.)
  latest_handle.reset();
  done_unload_servable.WaitForNotification();
  // The servable wasn't deleted in this thread.
  EXPECT_FALSE(FakeLoader::was_deleted_in_this_thread());
}
840 
841 MATCHER_P(EqualsServableState, servable_state, servable_state.DebugString()) {
842  if (arg == servable_state) {
843  return true;
844  }
845  *result_listener << arg.DebugString();
846  return false;
847 }
848 
849 TEST_P(AspiredVersionsManagerTest, EventBusErroneousVersion) {
850  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
851  const ServableId id = {kServableName, 3};
852  aspired_versions.push_back(
853  ServableData<std::unique_ptr<Loader>>(id, errors::Unknown("error")));
854  manager_->GetAspiredVersionsCallback()(kServableName,
855  std::move(aspired_versions));
856  HandlePendingAspiredVersionsRequests();
857 
858  const ServableState expected_published_state = {
859  id, ServableState::ManagerState::kEnd, errors::Unknown("error")};
860  EXPECT_THAT(*servable_state_monitor_.GetState(id),
861  EqualsServableState(expected_published_state));
862 }
863 
864 TEST_P(AspiredVersionsManagerTest, EventBusErrorOnLoad) {
865  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
866  const ServableId id = {kServableName, 7};
867  std::unique_ptr<Loader> loader(
868  new FakeLoader(7, errors::Internal("Error on load.")));
869  aspired_versions.push_back({id, std::move(loader)});
870  manager_->GetAspiredVersionsCallback()(kServableName,
871  std::move(aspired_versions));
872  HandlePendingAspiredVersionsRequests();
873 
874  const ServableState start_state = {id, ServableState::ManagerState::kStart,
875  OkStatus()};
876  EXPECT_THAT(*servable_state_monitor_.GetState(id),
877  EqualsServableState(start_state));
878 
879  // Unload version 0 and load the new aspired version. Version 1 may or may not
880  // be unloaded (depending on whether load/unload thread pools are used).
881  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
882  InvokePolicyAndExecuteAction();
883  }
884  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
885  {ServableState::ManagerState::kEnd});
886 
887  const ServableState error_state = {id, ServableState::ManagerState::kEnd,
888  errors::Internal("Error on load.")};
889  EXPECT_THAT(*servable_state_monitor_.GetState(id),
890  EqualsServableState(error_state));
891 }
892 
// Walks one servable through its whole lifecycle and checks each state that is
// published on the event bus: kStart -> kLoading -> kAvailable -> kUnloading
// -> kEnd. Notifications block the mock loader inside LoadWithMetadata() and
// Unload(), so the transient kLoading and kUnloading states can be observed
// from the main thread.
TEST_P(AspiredVersionsManagerTest, EventBusServableLifecycle) {
  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
  const ServableId id = {kServableName, 7};
  // Ownership moves to the manager through the unique_ptr below; `loader` is a
  // non-owning pointer kept only for setting expectations.
  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>();
  aspired_versions.push_back({id, std::unique_ptr<Loader>(loader)});
  manager_->GetAspiredVersionsCallback()(kServableName,
                                         std::move(aspired_versions));
  HandlePendingAspiredVersionsRequests();

  // Handling the request publishes the new version as kStart.
  const ServableState start_state = {id, ServableState::ManagerState::kStart,
                                     OkStatus()};
  EXPECT_THAT(*servable_state_monitor_.GetState(id),
              EqualsServableState(start_state));

  // Block inside the load until the main thread has checked kLoading.
  Notification load_called;
  Notification load_continue;
  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{id}))
      .WillOnce(InvokeWithoutArgs([&]() {
        load_called.Notify();
        load_continue.WaitForNotification();
        return OkStatus();
      }));

  // Policy actions run on their own thread because the load blocks.
  std::unique_ptr<Thread> load_unload_thread(
      Env::Default()->StartThread(ThreadOptions(), "LoadUnloadThread", [&]() {
        // Unload version 0 and load the new aspired version. Version 1 may or
        // may not be unloaded (depending on whether load/unload thread pools
        // are used).
        for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
          InvokePolicyAndExecuteAction();
        }
      }));

  load_called.WaitForNotification();

  // While the load is blocked, the servable is reported as kLoading.
  const ServableState loading_state = {
      id, ServableState::ManagerState::kLoading, OkStatus()};
  EXPECT_THAT(*servable_state_monitor_.GetState(id),
              EqualsServableState(loading_state));

  load_continue.Notify();
  WaitUntilServableManagerStateIsOneOf(
      servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});

  const ServableState available_state = {
      id, ServableState::ManagerState::kAvailable, OkStatus()};
  EXPECT_THAT(*servable_state_monitor_.GetState(id),
              EqualsServableState(available_state));

  // Un-aspire all versions (empty aspired list) so the servable is unloaded.
  manager_->GetAspiredVersionsCallback()(kServableName, {});
  HandlePendingAspiredVersionsRequests();

  // Likewise block inside Unload() until the main thread has checked
  // kUnloading.
  Notification unload_called;
  Notification unload_continue;
  EXPECT_CALL(*loader, Unload()).WillOnce(Invoke([&]() {
    unload_called.Notify();
    unload_continue.WaitForNotification();
  }));

  std::unique_ptr<Thread> unload_thread(
      Env::Default()->StartThread(ThreadOptions(), "UnloadThread", [&]() {
        // Call InvokePolicyAndExecuteAction() twice to unload version 1 and the
        // new version, in case version 1 has not been unloaded previously.
        InvokePolicyAndExecuteAction();
        InvokePolicyAndExecuteAction();
      }));

  unload_called.WaitForNotification();

  // While Unload() is blocked, the servable is reported as kUnloading.
  const ServableState unloading_state = {
      id, ServableState::ManagerState::kUnloading, OkStatus()};
  EXPECT_THAT(*servable_state_monitor_.GetState(id),
              EqualsServableState(unloading_state));

  unload_continue.Notify();
  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
                                       {ServableState::ManagerState::kEnd});

  // After Unload() returns, the servable ends in kEnd with an OK status.
  const ServableState end_state = {
      {kServableName, 7}, ServableState::ManagerState::kEnd, OkStatus()};
  EXPECT_THAT(*servable_state_monitor_.GetState(id),
              EqualsServableState(end_state));
}
976 
977 // Tests whether there are any errors if we don't have an event bus configured.
978 TEST_P(AspiredVersionsManagerTest, NoEventBus) {
979  AspiredVersionsManager::Options options;
980  // The state manager thread won't be run automatically.
981  options.manage_state_interval_micros = -1;
982  options.env = Env::Default();
983  options.aspired_version_policy.reset(new AvailabilityPreservingPolicy());
984  std::unique_ptr<AspiredVersionsManager> aspired_versions_manager;
985  TF_ASSERT_OK(AspiredVersionsManager::Create(std::move(options),
986  &aspired_versions_manager));
987 
988  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
989  const ServableId id = {kServableName, 7};
990  std::unique_ptr<Loader> loader(new FakeLoader(7));
991  aspired_versions.push_back({id, std::move(loader)});
992  aspired_versions_manager->GetAspiredVersionsCallback()(
993  kServableName, std::move(aspired_versions));
994  HandlePendingAspiredVersionsRequests();
995 }
996 
997 TEST_P(AspiredVersionsManagerTest, RetryOnLoadErrorFinallySucceeds) {
998  CHECK_GE(max_num_load_retries_, 1);
999  const ServableId id = {kServableName, 7};
1000  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
1001  // We succeed on the last load, before the manager gives up.
1002  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{id}))
1003  .WillOnce(Return(errors::Internal("Error on load.")))
1004  .WillOnce(Return(OkStatus()));
1005 
1006  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1007  aspired_versions.push_back({id, std::unique_ptr<Loader>(loader)});
1008  manager_->GetAspiredVersionsCallback()(kServableName,
1009  std::move(aspired_versions));
1010  HandlePendingAspiredVersionsRequests();
1011 
1012  // Unload version 0 and load the new aspired version. Version 1 may or may not
1013  // be unloaded (depending on whether load/unload thread pools are used).
1014  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
1015  InvokePolicyAndExecuteAction();
1016  }
1017  WaitUntilServableManagerStateIsOneOf(
1018  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
1019 
1020  const ServableState available_state = {
1021  id, ServableState::ManagerState::kAvailable, OkStatus()};
1022  EXPECT_THAT(*servable_state_monitor_.GetState(id),
1023  EqualsServableState(available_state));
1024 }
1025 
1026 TEST_P(AspiredVersionsManagerTest, RetryOnLoadErrorFinallyFails) {
1027  CHECK_GE(max_num_load_retries_, 1);
1028 
1029  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1030  const ServableId id = {kServableName, 7};
1031  // We always fail.
1032  std::unique_ptr<Loader> loader(
1033  new FakeLoader(7, errors::Internal("Error on load.")));
1034  aspired_versions.push_back({id, std::move(loader)});
1035  manager_->GetAspiredVersionsCallback()(kServableName,
1036  std::move(aspired_versions));
1037  HandlePendingAspiredVersionsRequests();
1038 
1039  // Unload version 0 and load the new aspired version. Version 1 may or may not
1040  // be unloaded (depending on whether load/unload thread pools are used).
1041  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
1042  InvokePolicyAndExecuteAction();
1043  }
1044  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
1045  {ServableState::ManagerState::kEnd});
1046 
1047  const ServableState error_state = {id, ServableState::ManagerState::kEnd,
1048  errors::Internal("Error on load.")};
1049  EXPECT_THAT(*servable_state_monitor_.GetState(id),
1050  EqualsServableState(error_state));
1051 }
1052 
1053 // Tests the interaction between AspiredVersionsManager and the
1054 // AvailabilityPreservingPolicy.
1055 // Specifically, we want to make sure that the manager will not try to unload
1056 // all serving versions that are no longer aspired if the new aspired version
1057 // was not able to start serving.
1058 TEST_P(AspiredVersionsManagerTest, AspireErrorDontUnload) {
1059  const std::vector<ServableId> expected_before = {{kServableName, 0},
1060  {kServableName, 1},
1061  {kServableName2, 0},
1062  {kServableName2, 1}};
1063  EXPECT_THAT(manager_->ListAvailableServableIds(),
1064  UnorderedElementsAreArray(expected_before));
1065 
1066  // Set stream kServableName to have servable 7.
1067  // This causes 0 & 1 to be set to not aspired and 7 to be loaded, but 7 errors
1068  // on load, so never moves to a loaded state.
1069  {
1070  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1071  const ServableId id = {kServableName, 7};
1072  std::unique_ptr<Loader> loader(
1073  new FakeLoader(7, errors::Internal("An error.")));
1074  aspired_versions.push_back({id, std::move(loader)});
1075  manager_->GetAspiredVersionsCallback()(kServableName,
1076  std::move(aspired_versions));
1077  HandlePendingAspiredVersionsRequests();
1078 
1079  // Will unload version 0.
1080  InvokePolicyAndExecuteAction();
1081  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
1082  {kServableName, 0},
1083  {ServableState::ManagerState::kEnd});
1084 
1085  // Will try to load version 7 and fail.
1086  InvokePolicyAndExecuteAction();
1087  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
1088  {ServableState::ManagerState::kEnd});
1089  }
1090 
1091  // For kServableName, version 0 has been unloaded. For kServableName2, both
1092  // versions should still be loaded.
1093  const std::vector<ServableId> expected_after_first_load = {
1094  {kServableName, 1}, {kServableName2, 0}, {kServableName2, 1}};
1095  EXPECT_THAT(manager_->ListAvailableServableIds(),
1096  UnorderedElementsAreArray(expected_after_first_load));
1097 
1098  // Now successfully loading a new version should allow the older versions to
1099  // be unloaded.
1100  {
1101  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1102  const ServableId id = {kServableName, 8};
1103  std::unique_ptr<Loader> loader(new FakeLoader(8));
1104  aspired_versions.push_back({id, std::move(loader)});
1105  manager_->GetAspiredVersionsCallback()(kServableName,
1106  std::move(aspired_versions));
1107  HandlePendingAspiredVersionsRequests();
1108 
1109  // Will try to load version 8 and succeed.
1110  InvokePolicyAndExecuteAction();
1111  WaitUntilServableManagerStateIsOneOf(
1112  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
1113 
1114  // Will unload version 1.
1115  InvokePolicyAndExecuteAction();
1116  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
1117  {kServableName, 1},
1118  {ServableState::ManagerState::kEnd});
1119  }
1120 }
1121 
TEST_P(AspiredVersionsManagerTest, UnaspireThenImmediatelyReaspire) {
  // This test exercises a scenario in which a servable has been unaspired, and
  // while it is still being managed (e.g. loading, serving or unloading) it
  // gets reaspired (with a new loader). The manager should wait for the
  // original loader to get taken down via the normal process for unaspired
  // loaders, and then proceed to bring up the new loader.

  const ServableId id = {kServableName, 7};

  // Ownership of each mock loader moves to the manager through its unique_ptr;
  // the raw pointers are kept only for setting expectations.
  std::vector<ServableData<std::unique_ptr<Loader>>> first_aspired_versions;
  test_util::MockLoader* first_loader = new NiceMock<test_util::MockLoader>();
  first_aspired_versions.push_back({id, std::unique_ptr<Loader>(first_loader)});
  EXPECT_CALL(*first_loader, LoadWithMetadata(Loader::Metadata{id}))
      .WillOnce(Return(OkStatus()));
  manager_->GetAspiredVersionsCallback()(kServableName,
                                         std::move(first_aspired_versions));
  HandlePendingAspiredVersionsRequests();

  // Wait for version 0 to be unloaded and the new aspired version to be loaded.
  // If we don't wait, the first_loader_handle below may be obtained before
  // the loading or unloading finishes, which may block the loading or
  // unloading.
  InvokePolicyAndExecuteAction();
  InvokePolicyAndExecuteAction();
  WaitUntilServableManagerStateIsOneOf(
      servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
                                       {kServableName, 0},
                                       {ServableState::ManagerState::kEnd});

  // Pin 'first_loader' in the manager by holding a handle to its servable.
  int servable = 42;
  EXPECT_CALL(*first_loader, servable()).WillOnce(InvokeWithoutArgs([&]() {
    return AnyPtr{&servable};
  }));
  auto first_loader_handle =
      std::unique_ptr<ServableHandle<int>>(new ServableHandle<int>);
  TF_ASSERT_OK(manager_->GetServableHandle(ServableRequest::FromId(id),
                                           first_loader_handle.get()));

  // Now, we'll un-aspire the servable, and then re-aspire it with a new loader.
  // The manager should wait until it is able to unload the first loader, then
  // bring up the second loader.

  Notification first_unload_called;
  EXPECT_CALL(*first_loader, Unload()).WillOnce(InvokeWithoutArgs([&]() {
    first_unload_called.Notify();
  }));

  std::vector<ServableData<std::unique_ptr<Loader>>> empty_aspired_versions;
  manager_->GetAspiredVersionsCallback()(kServableName,
                                         std::move(empty_aspired_versions));
  HandlePendingAspiredVersionsRequests();

  // The following thread will block trying to unload the first loader, while we
  // hold the handle.
  std::unique_ptr<Thread> unload_thread(
      Env::Default()->StartThread(ThreadOptions(), "UnloadThread", [&]() {
        // Unload version 1 and the newly un-aspired version.
        InvokePolicyAndExecuteAction();
        InvokePolicyAndExecuteAction();
      }));

  // Re-aspire the servable with a fresh loader. This request is deliberately
  // not handled here; the ReaspireThread loop below handles it.
  std::vector<ServableData<std::unique_ptr<Loader>>> second_aspired_versions;
  test_util::MockLoader* second_loader = new NiceMock<test_util::MockLoader>();
  second_aspired_versions.push_back(
      {id, std::unique_ptr<Loader>(second_loader)});
  Notification second_load_called;
  EXPECT_CALL(*second_loader, LoadWithMetadata(Loader::Metadata{id}))
      .WillOnce(InvokeWithoutArgs([&]() {
        second_load_called.Notify();
        return OkStatus();
      }));
  manager_->GetAspiredVersionsCallback()(kServableName,
                                         std::move(second_aspired_versions));

  // Run the manager's background logic in a loop. Nothing should happen for now
  // because the first loader is pinned.
  std::unique_ptr<Thread> reaspire_thread(
      Env::Default()->StartThread(ThreadOptions(), "ReaspireThread", [&]() {
        while (!second_load_called.HasBeenNotified()) {
          FlushServables();
          HandlePendingAspiredVersionsRequests();
          InvokePolicyAndExecuteAction();
          Env::Default()->SleepForMicroseconds(1000 /* 1 ms */);
        }
      }));
  // Give the loop time to run. While the handle is held, neither the first
  // unload nor the second load may have happened.
  Env::Default()->SleepForMicroseconds(50 * 1000 /* 50 ms */);
  EXPECT_FALSE(first_unload_called.HasBeenNotified());
  EXPECT_FALSE(second_load_called.HasBeenNotified());

  // Unpin the first loader. The manager should unload the first loader and
  // bring up the second loader.
  first_loader_handle = nullptr;
  first_unload_called.WaitForNotification();
  second_load_called.WaitForNotification();
}
1220 
1221 TEST_P(AspiredVersionsManagerTest,
1222  UnaspireFailedServableThenImmediatelyReaspire) {
1223  // Like UnaspireThenImmediatelyReaspire, but covers the case in which the
1224  // servable fails to load the first time it is aspired.
1225 
1226  const ServableId id = {kServableName, 7};
1227 
1228  std::vector<ServableData<std::unique_ptr<Loader>>> first_aspired_versions;
1229  test_util::MockLoader* first_loader = new NiceMock<test_util::MockLoader>();
1230  first_aspired_versions.push_back({id, std::unique_ptr<Loader>(first_loader)});
1231  EXPECT_CALL(*first_loader, LoadWithMetadata(Loader::Metadata{id}))
1232  .WillRepeatedly(Return(
1233  Status(static_cast<tsl::errors::Code>(absl::StatusCode::kUnknown),
1234  "first load failing")));
1235  manager_->GetAspiredVersionsCallback()(kServableName,
1236  std::move(first_aspired_versions));
1237  HandlePendingAspiredVersionsRequests();
1238  // Unload version 0 and load the new aspired version. Version 1 may or may not
1239  // be unloaded (depending on whether load/unload thread pools are used).
1240  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
1241  InvokePolicyAndExecuteAction();
1242  }
1243  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
1244  {ServableState::ManagerState::kEnd});
1245 
1246  // Now, we'll un-aspire the servable, and then re-aspire it with a new loader.
1247  // The manager should wait until it is able to flush the first loader, then
1248  // bring up the second loader.
1249 
1250  std::vector<ServableData<std::unique_ptr<Loader>>> empty_aspired_versions;
1251  manager_->GetAspiredVersionsCallback()(kServableName,
1252  std::move(empty_aspired_versions));
1253  HandlePendingAspiredVersionsRequests();
1254 
1255  // Re-aspire the servable with a fresh loader.
1256  std::vector<ServableData<std::unique_ptr<Loader>>> second_aspired_versions;
1257  test_util::MockLoader* second_loader = new NiceMock<test_util::MockLoader>();
1258  second_aspired_versions.push_back(
1259  {id, std::unique_ptr<Loader>(second_loader)});
1260  Notification second_load_called;
1261  EXPECT_CALL(*second_loader, LoadWithMetadata(Loader::Metadata{id}))
1262  .WillOnce(InvokeWithoutArgs([&]() {
1263  second_load_called.Notify();
1264  return OkStatus();
1265  }));
1266  manager_->GetAspiredVersionsCallback()(kServableName,
1267  std::move(second_aspired_versions));
1268 
1269  // Run the manager's background logic in a loop, but sans FlushServables().
1270  // Nothing should happen for now because the first loader isn't flushed.
1271  std::unique_ptr<Thread> reaspire_thread(
1272  Env::Default()->StartThread(ThreadOptions(), "ReaspireThread", [&]() {
1273  while (!second_load_called.HasBeenNotified()) {
1274  HandlePendingAspiredVersionsRequests();
1275  InvokePolicyAndExecuteAction();
1276  Env::Default()->SleepForMicroseconds(1000 /* 1 ms */);
1277  }
1278  }));
1279  Env::Default()->SleepForMicroseconds(50 * 1000 /* 50 ms */);
1280  EXPECT_FALSE(second_load_called.HasBeenNotified());
1281 
1282  // Flush the first loader. The manager should finally bring up the second
1283  // loader.
1284  FlushServables();
1285  second_load_called.WaitForNotification();
1286 }
1287 
TEST_P(AspiredVersionsManagerTest, UnaspireNewServableThenImmediatelyReaspire) {
  // Like UnaspireThenImmediatelyReaspire, but covers the case in which the
  // servable is in state kNew when it gets unaspired.
  // (Regression test for b/27766674.)

  const ServableId id = {kServableName, 7};

  // The first loader must never be loaded: it is un-aspired while still kNew.
  // Ownership moves to the manager through the unique_ptr; the raw pointer is
  // kept only for setting expectations.
  std::vector<ServableData<std::unique_ptr<Loader>>> first_aspired_versions;
  test_util::MockLoader* first_loader = new NiceMock<test_util::MockLoader>();
  EXPECT_CALL(*first_loader, LoadWithMetadata(Loader::Metadata{id})).Times(0);
  first_aspired_versions.push_back({id, std::unique_ptr<Loader>(first_loader)});
  manager_->GetAspiredVersionsCallback()(kServableName,
                                         std::move(first_aspired_versions));
  HandlePendingAspiredVersionsRequests();
  // (We *don't* call InvokePolicyAndExecuteAction(), thus causing the servable
  // to remain in state kNew.)

  // Now, we'll un-aspire the servable, and then re-aspire it with a new loader.
  // The manager should get rid of the first loader, then bring up the second
  // one.

  std::vector<ServableData<std::unique_ptr<Loader>>> empty_aspired_versions;
  manager_->GetAspiredVersionsCallback()(kServableName,
                                         std::move(empty_aspired_versions));
  HandlePendingAspiredVersionsRequests();

  // Re-aspire the servable with a fresh loader.
  std::vector<ServableData<std::unique_ptr<Loader>>> second_aspired_versions;
  test_util::MockLoader* second_loader = new NiceMock<test_util::MockLoader>();
  second_aspired_versions.push_back(
      {id, std::unique_ptr<Loader>(second_loader)});
  Notification second_load_called;
  EXPECT_CALL(*second_loader, LoadWithMetadata(Loader::Metadata{id}))
      .WillOnce(InvokeWithoutArgs([&]() {
        second_load_called.Notify();
        return OkStatus();
      }));
  manager_->GetAspiredVersionsCallback()(kServableName,
                                         std::move(second_aspired_versions));
  // The first HandlePendingAspiredVersionsRequests() call will do nothing,
  // because the first loader remains in the manager (with state kNew).
  HandlePendingAspiredVersionsRequests();
  // FlushServables() should remove the first loader, thus clearing the way for
  // a subsequent HandlePendingAspiredVersionsRequests() call to accept the
  // second loader.
  FlushServables();
  HandlePendingAspiredVersionsRequests();
  // Unload version 0 and load the new aspired version. Version 1 may or may not
  // be unloaded (depending on whether load/unload thread pools are used).
  for (int i = 0; i < kNumVersionsPerServable + 1; ++i) {
    InvokePolicyAndExecuteAction();
  }
  // Completes only once the second loader's load has been invoked.
  second_load_called.WaitForNotification();
}
1342 
1343 class MockAspiredVersionPolicy : public AspiredVersionPolicy {
1344  public:
1345  MOCK_METHOD(absl::optional<ServableAction>, GetNextAction,
1346  (const std::vector<AspiredServableStateSnapshot>&),
1347  (const, override));
1348 };
1349 
1350 TEST(AspiredVersionsManagerTest, CallPolicyWithAllVersions) {
1351  std::unique_ptr<AspiredVersionsManager> manager;
1352  AspiredVersionsManager::Options manager_options;
1353  MockAspiredVersionPolicy* policy = new MockAspiredVersionPolicy;
1354  // The state manager thread won't be run automatically.
1355  manager_options.manage_state_interval_micros = -1;
1356  manager_options.aspired_version_policy =
1357  std::unique_ptr<AspiredVersionPolicy>(policy);
1358  TF_CHECK_OK(
1359  AspiredVersionsManager::Create(std::move(manager_options), &manager));
1360  std::set<ServableId> servables;
1361  std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1362  for (int i = 0; i < kNumVersionsPerServable; ++i) {
1363  const ServableId id = {kServableName, i};
1364  aspired_versions.push_back(CreateAspiredVersion(id));
1365  servables.insert(id);
1366  }
1367  manager->GetAspiredVersionsCallback()(kServableName,
1368  std::move(aspired_versions));
1369  test_util::AspiredVersionsManagerTestAccess(manager.get())
1370  .HandlePendingAspiredVersionsRequests();
1371 
1372  std::vector<AspiredServableStateSnapshot> all_versions;
1373  EXPECT_CALL(*policy, GetNextAction(_))
1374  .WillOnce(Invoke(
1375  [&all_versions](
1376  const std::vector<AspiredServableStateSnapshot>& snapshots) {
1377  all_versions = snapshots;
1378  return absl::nullopt;
1379  }));
1380  test_util::AspiredVersionsManagerTestAccess(manager.get())
1381  .InvokePolicyAndExecuteAction();
1382  EXPECT_EQ(kNumVersionsPerServable, all_versions.size());
1383 }
1384 
1385 } // namespace
1386 } // namespace serving
1387 } // namespace tensorflow
absl::optional< ServableState > GetState(const ServableId &servable_id) const TF_LOCKS_EXCLUDED(mu_)