TensorFlow Serving C++ API Documentation
basic_manager_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/basic_manager.h"
17 
18 #include <algorithm>
19 #include <atomic>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <set>
24 #include <utility>
25 #include <vector>
26 
27 #include <gmock/gmock.h>
28 #include <gtest/gtest.h>
29 #include "absl/status/status.h"
30 #include "absl/types/optional.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/blocking_counter.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/null_file_system.h"
37 #include "tensorflow/core/protobuf/error_codes.pb.h"
38 #include "tensorflow_serving/core/servable_state_monitor.h"
39 #include "tensorflow_serving/core/test_util/availability_test_util.h"
40 #include "tensorflow_serving/core/test_util/fake_loader.h"
41 #include "tensorflow_serving/core/test_util/manager_test_util.h"
42 #include "tensorflow_serving/core/test_util/mock_loader.h"
43 #include "tensorflow_serving/util/any_ptr.h"
44 #include "tensorflow_serving/util/event_bus.h"
45 #include "tensorflow_serving/util/threadpool_executor.h"
46 
47 namespace tensorflow {
48 namespace serving {
49 namespace {
50 
51 using test_util::FakeLoader;
52 using test_util::WaitUntilServableManagerStateIsOneOf;
53 using ::testing::_;
54 using ::testing::AnyOf;
55 using ::testing::HasSubstr;
56 using ::testing::InSequence;
57 using ::testing::Invoke;
58 using ::testing::InvokeWithoutArgs;
59 using ::testing::MockFunction;
60 using ::testing::NiceMock;
61 using ::testing::Return;
62 using ::testing::UnorderedElementsAre;
63 using ::testing::UnorderedElementsAreArray;
64 
65 constexpr char kServableName[] = "kServableName";
66 constexpr char kServableName2[] = "kServableName2";
67 constexpr char kServableName3[] = "kServableName3";
68 
69 constexpr int kNumVersionsPerServable = 2;
70 
71 constexpr int kNumThreads = 10;
72 
73 MATCHER_P(EqualsServableState, servable_state, servable_state.DebugString()) {
74  if (arg == servable_state) {
75  return true;
76  }
77  *result_listener << arg.DebugString();
78  return false;
79 }
80 
81 // Creates a ServableData around a FakeLoader.
82 ServableData<std::unique_ptr<Loader>> CreateServable(
83  const ServableId& id, const Status load_status = OkStatus()) {
84  std::unique_ptr<Loader> loader(new FakeLoader(id.version, load_status));
85  return CreateServableData(id, std::move(loader));
86 }
87 
88 // We parameterize this test with the number of load & unload threads. (Zero
89 // means use an in-line executor instead of a thread pool.)
90 struct ThreadPoolSizes {
91  uint64_t num_load_threads;
92  uint64_t num_unload_threads;
93 };
94 class BasicManagerTest : public ::testing::TestWithParam<ThreadPoolSizes> {
95  protected:
96  BasicManagerTest()
97  : thread_pool_sizes_(GetParam()),
98  servable_event_bus_(EventBus<ServableState>::CreateEventBus()),
99  servable_state_monitor_(servable_event_bus_.get()) {
100  BasicManager::Options options;
101  options.num_load_threads = thread_pool_sizes_.num_load_threads;
102  options.num_unload_threads = thread_pool_sizes_.num_unload_threads;
103  options.servable_event_bus = servable_event_bus_.get();
104  options.max_num_load_retries = 10;
105  options.load_retry_interval_micros = 0;
106  TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager_));
107  }
108 
109  void SetUp() override {
110  // We load the manager with two different servable streams, each with two
111  // versions 0 and 1.
112  std::set<ServableId> loaded_servables;
113  for (const char* servable_name : {kServableName, kServableName2}) {
114  for (int i = 1; i <= kNumVersionsPerServable; ++i) {
115  const ServableId id = {servable_name, i};
116  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
117  basic_manager_->LoadServable(
118  id, [](const Status& status) { TF_ASSERT_OK(status); });
119  loaded_servables.insert(id);
120  }
121  }
122  for (const ServableId& loaded_servable : loaded_servables) {
123  WaitUntilServableManagerStateIsOneOf(
124  servable_state_monitor_, loaded_servable,
125  {ServableState::ManagerState::kAvailable});
126  }
127  }
128 
129  ThreadPoolSizes thread_pool_sizes_;
130  std::shared_ptr<EventBus<ServableState>> servable_event_bus_;
131  ServableStateMonitor servable_state_monitor_;
132  std::unique_ptr<BasicManager> basic_manager_;
133 };
134 
135 INSTANTIATE_TEST_CASE_P(
136  WithOrWithoutThreadPools, BasicManagerTest,
137  ::testing::Values(
138  ThreadPoolSizes{0, 0} /* without load or unload threadpools */,
139  ThreadPoolSizes{2, 0} /* with just a load threadpool */,
140  ThreadPoolSizes{0, 2} /* with just an unload threadpool */,
141  ThreadPoolSizes{4, 4} /* with load and unload threadpools */));
142 
143 TEST_P(BasicManagerTest, ServableHandleNotFoundMissingLoaderName) {
144  ServableHandle<int64_t> handle;
145  const Status status = basic_manager_->GetServableHandle(
146  ServableRequest::Latest(strings::StrCat(kServableName, "missing")),
147  &handle);
148  ASSERT_FALSE(status.ok()) << status;
149  EXPECT_EQ(error::NOT_FOUND, status.code());
150 }
151 
152 TEST_P(BasicManagerTest, ServableHandleNotFoundMissingVersion) {
153  // This version is missing.
154  const int64_t missing_version = 100;
155  ServableHandle<int64_t> handle;
156  const Status status = basic_manager_->GetServableHandle(
157  ServableRequest::Specific(kServableName, missing_version), &handle);
158  ASSERT_FALSE(status.ok()) << status;
159  EXPECT_EQ(error::NOT_FOUND, status.code());
160 }
161 
162 TEST_P(BasicManagerTest, ServableHandleEarliest) {
163  ASSERT_GT(kNumVersionsPerServable, 1);
164  ServableHandle<int64_t> handle;
165  const Status status = basic_manager_->GetServableHandle(
166  ServableRequest::Earliest(kServableName), &handle);
167  TF_ASSERT_OK(status);
168  EXPECT_EQ(1, *handle);
169 }
170 
171 TEST_P(BasicManagerTest, ServableHandleLatest) {
172  const ServableId id = {kServableName, kNumVersionsPerServable + 1};
173  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
174  basic_manager_->LoadServable(
175  id, [](const Status& status) { TF_ASSERT_OK(status); });
176  WaitUntilServableManagerStateIsOneOf(
177  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
178 
179  ServableHandle<int64_t> handle;
180  const Status status = basic_manager_->GetServableHandle(
181  ServableRequest::Latest(kServableName), &handle);
182  TF_ASSERT_OK(status);
183  EXPECT_EQ(kNumVersionsPerServable + 1, *handle);
184 }
185 
186 TEST_P(BasicManagerTest, AlreadyManagedError) {
187  const ServableId id = {"banana", 42};
188  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
189  EXPECT_FALSE(basic_manager_->ManageServable(CreateServable(id)).ok());
190 }
191 
192 // Tests the case where the latest version of a servable available is 0.
193 TEST_P(BasicManagerTest, ServableHandleLatestVersionIsZero) {
194  const ServableId id = {kServableName3, 1};
195  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
196  basic_manager_->LoadServable(
197  id, [](const Status& status) { TF_ASSERT_OK(status); });
198  WaitUntilServableManagerStateIsOneOf(
199  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
200 
201  ServableHandle<int64_t> handle;
202  const Status status = basic_manager_->GetServableHandle(
203  ServableRequest::Latest(kServableName3), &handle);
204  TF_ASSERT_OK(status);
205  EXPECT_EQ(1, *handle);
206  EXPECT_EQ(id, handle.id());
207 }
208 
209 TEST_P(BasicManagerTest, StopManagingUnknownId) {
210  const ServableId id = {kServableName3, 1};
211  EXPECT_FALSE(basic_manager_->StopManagingServable(id).ok());
212 }
213 
214 TEST_P(BasicManagerTest, StopManagingActiveServable) {
215  const ServableId id = {kServableName3, 1};
216  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
217  basic_manager_->LoadServable(
218  id, [](const Status& status) { TF_EXPECT_OK(status); });
219  WaitUntilServableManagerStateIsOneOf(
220  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
221  EXPECT_FALSE(basic_manager_->StopManagingServable(id).ok());
222 }
223 
224 TEST_P(BasicManagerTest, StopManagingDisabledServable) {
225  const ServableId id = {kServableName3, 1};
226  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
227  basic_manager_->LoadServable(
228  id, [](const Status& status) { TF_EXPECT_OK(status); });
229  WaitUntilServableManagerStateIsOneOf(
230  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
231  basic_manager_->UnloadServable(
232  id, [](const Status& status) { TF_EXPECT_OK(status); });
233  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
234  {ServableState::ManagerState::kEnd});
235  const absl::optional<ServableStateSnapshot<>> snapshot =
236  basic_manager_->GetManagedServableStateSnapshot(id);
237  EXPECT_EQ(LoaderHarness::State::kDisabled, snapshot->state);
238  const ServableState expected_state = {id, ServableState::ManagerState::kEnd,
239  OkStatus()};
240  EXPECT_THAT(*servable_state_monitor_.GetState(id),
241  EqualsServableState(expected_state));
242 
243  TF_ASSERT_OK(basic_manager_->StopManagingServable(id));
244  EXPECT_FALSE(basic_manager_->GetManagedServableStateSnapshot(id));
245 }
246 
247 TEST_P(BasicManagerTest, DontStopManagingOnError) {
248  const ServableId id = {kServableName, 7};
249  const Status error_status = errors::Internal("An error.");
250  std::unique_ptr<Loader> loader(new FakeLoader(7, error_status));
251  TF_ASSERT_OK(basic_manager_->ManageServable({id, std::move(loader)}));
252  basic_manager_->LoadServable(id, [error_status](const Status& status) {
253  EXPECT_EQ(error_status, status);
254  });
255  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
256  {ServableState::ManagerState::kEnd});
257  const absl::optional<ServableStateSnapshot<>> snapshot =
258  basic_manager_->GetManagedServableStateSnapshot(id);
259  EXPECT_EQ(LoaderHarness::State::kError, snapshot->state);
260  const ServableState expected_error_state = {
261  id, ServableState::ManagerState::kEnd, error_status};
262  EXPECT_THAT(*servable_state_monitor_.GetState(id),
263  EqualsServableState(expected_error_state));
264 }
265 
266 TEST_P(BasicManagerTest, ServableHandleSpecificVersion) {
267  ServableHandle<int64_t> handle;
268  const ServableId id = {kServableName2, 1};
269  const Status status =
270  basic_manager_->GetServableHandle(ServableRequest::FromId(id), &handle);
271  TF_ASSERT_OK(status);
272  EXPECT_EQ(1, *handle);
273  EXPECT_EQ(id, handle.id());
274 }
275 
276 // Tests an edge-case when the serving map is updated and the last version of a
277 // stream is not in kReady state.
278 TEST_P(BasicManagerTest, UpdateServingMapServableHandleLatest) {
279  // Using kServableName3 which doesn't have any servables loaded in the
280  // manager, as opposed to kServableName which already has 2 loaded.
281  const ServableId id0 = {kServableName3, 0};
282  // Servable is int64_t with value 0.
283  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id0)));
284  basic_manager_->LoadServable(
285  id0, [](const Status& status) { TF_ASSERT_OK(status); });
286  WaitUntilServableManagerStateIsOneOf(
287  servable_state_monitor_, id0, {ServableState::ManagerState::kAvailable});
288 
289  test_util::MockLoader* notify_to_unload = new NiceMock<test_util::MockLoader>;
290  // Don't make it const otherwise servable types will mismatch: const int64_t
291  // vs int64.
292  int64_t servable = 1;
293  ON_CALL(*notify_to_unload, servable())
294  .WillByDefault(Return(AnyPtr(&servable)));
295  ON_CALL(*notify_to_unload, EstimateResources(_))
296  .WillByDefault(Return(OkStatus()));
297  ON_CALL(*notify_to_unload, LoadWithMetadata(Loader::Metadata{id0}))
298  .WillByDefault(Return(OkStatus()));
299  const ServableId id1 = {kServableName3, 1};
300  TF_ASSERT_OK(basic_manager_->ManageServable(
301  {id1, std::unique_ptr<Loader>(notify_to_unload)}));
302  basic_manager_->LoadServable(
303  id1, [](const Status& status) { TF_ASSERT_OK(status); });
304  WaitUntilServableManagerStateIsOneOf(
305  servable_state_monitor_, id1, {ServableState::ManagerState::kAvailable});
306 
307  // We have loaded both versions 0 and 1 of kServableName3, so the latest
308  // handle should be that of v1.
309  {
310  ServableHandle<int64_t> handle;
311  const Status status = basic_manager_->GetServableHandle(
312  ServableRequest::Latest(kServableName3), &handle);
313  TF_ASSERT_OK(status);
314  EXPECT_EQ(id1, handle.id());
315  }
316 
317  // We will now try to unload v1, but we only allow it to move out from kReady
318  // state, and not complete the unload. Also, after it moves out from kReady,
319  // the serving map is also updated, so v0 would be the latest.
320  Notification unload_started;
321  Notification finish_unload;
322  EXPECT_CALL(*notify_to_unload, Unload()).WillOnce(Invoke([&]() {
323  unload_started.Notify();
324  finish_unload.WaitForNotification();
325  }));
326  Notification unload_finished;
327  std::unique_ptr<Thread> unload_last_servable(
328  Env::Default()->StartThread({}, "UnloadLastServable", [&]() {
329  basic_manager_->UnloadServable(id1, [&](const Status& status) {
330  TF_EXPECT_OK(status);
331  unload_finished.Notify();
332  });
333  }));
334  unload_started.WaitForNotification();
335 
336  // Servable map should just have {kServableName3, 0} at this point.
337  {
338  ServableHandle<int64_t> handle;
339  const Status status = basic_manager_->GetServableHandle(
340  ServableRequest::Latest(kServableName3), &handle);
341  TF_EXPECT_OK(status);
342  EXPECT_EQ(id0, handle.id());
343  }
344  finish_unload.Notify();
345  // We have to ensure that the unload has finished completely, otherwise the
346  // address of the notifications could be invalid in the load when we exit from
347  // this scope.
348  unload_finished.WaitForNotification();
349 }
350 
351 TEST_P(BasicManagerTest, ListAvailableServableIds) {
352  const std::vector<ServableId> expected_before = {{kServableName, 1},
353  {kServableName, 2},
354  {kServableName2, 1},
355  {kServableName2, 2}};
356  EXPECT_THAT(basic_manager_->ListAvailableServableIds(),
357  UnorderedElementsAreArray(expected_before));
358 
359  // Set stream kServableName to have servables 7 and unload 0 & 1, but 7 errors
360  // on load, so never moves to a loaded state.
361  const ServableId id = {kServableName, 7};
362  std::unique_ptr<Loader> loader(
363  new FakeLoader(7, errors::Internal("An error.")));
364  TF_ASSERT_OK(basic_manager_->ManageServable(
365  CreateServableData(id, std::move(loader))));
366  basic_manager_->LoadServable(id, [](const Status& status) {
367  EXPECT_EQ(errors::Internal("An error."), status);
368  });
369  basic_manager_->UnloadServable(
370  {kServableName, 1}, [](const Status& status) { TF_ASSERT_OK(status); });
371  basic_manager_->UnloadServable(
372  {kServableName, 2}, [](const Status& status) { TF_ASSERT_OK(status); });
373  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
374  {ServableState::ManagerState::kEnd});
375  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
376  {kServableName, 1},
377  {ServableState::ManagerState::kEnd});
378  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
379  {kServableName, 2},
380  {ServableState::ManagerState::kEnd});
381 
382  const std::vector<ServableId> expected_after = {{kServableName2, 1},
383  {kServableName2, 2}};
384  EXPECT_THAT(basic_manager_->ListAvailableServableIds(),
385  UnorderedElementsAreArray(expected_after));
386 }
387 
388 TEST_P(BasicManagerTest, GetAvailableServableHandles) {
389  // Scoped to destruct handles at the end of it.
390  {
391  const std::map<ServableId, ServableHandle<int64_t>> handles_before =
392  basic_manager_->GetAvailableServableHandles<int64_t>();
393  ASSERT_EQ(kNumVersionsPerServable * 2, handles_before.size());
394 
395  const std::vector<ServableId> expected_ids_before = {{kServableName, 1},
396  {kServableName, 2},
397  {kServableName2, 1},
398  {kServableName2, 2}};
399  for (const ServableId& expected_id : expected_ids_before) {
400  const auto found_it = handles_before.find(expected_id);
401  ASSERT_TRUE(found_it != handles_before.end());
402  EXPECT_EQ(expected_id.version, *found_it->second);
403  }
404  }
405 
406  // Set stream kServableName to have servables 7 and unload 0 & 1, but 7 errors
407  // on load, so never moves to a loaded state.
408  const ServableId id = {kServableName, 7};
409  std::unique_ptr<Loader> loader(
410  new FakeLoader(7, errors::Internal("An error.")));
411  TF_ASSERT_OK(basic_manager_->ManageServable(
412  CreateServableData(id, std::move(loader))));
413  basic_manager_->LoadServable(id, [](const Status& status) {
414  EXPECT_EQ(errors::Internal("An error."), status);
415  });
416  basic_manager_->UnloadServable(
417  {kServableName, 1}, [](const Status& status) { TF_ASSERT_OK(status); });
418  basic_manager_->UnloadServable(
419  {kServableName, 2}, [](const Status& status) { TF_ASSERT_OK(status); });
420  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
421  {ServableState::ManagerState::kEnd});
422  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
423  {kServableName, 1},
424  {ServableState::ManagerState::kEnd});
425  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
426  {kServableName, 2},
427  {ServableState::ManagerState::kEnd});
428 
429  {
430  const std::map<ServableId, ServableHandle<int64_t>> handles_after =
431  basic_manager_->GetAvailableServableHandles<int64_t>();
432  ASSERT_EQ(kNumVersionsPerServable, handles_after.size());
433 
434  const std::vector<ServableId> expected_ids_after = {{kServableName2, 1},
435  {kServableName2, 2}};
436  for (const ServableId& expected_id : expected_ids_after) {
437  const auto found_it = handles_after.find(expected_id);
438  ASSERT_TRUE(found_it != handles_after.end());
439  EXPECT_EQ(expected_id.version, *found_it->second);
440  }
441  }
442 }
443 
444 TEST_P(BasicManagerTest, GetAvailableServableHandlesWrongType) {
445  const std::map<ServableId, ServableHandle<int>> wrong_type_handles =
446  basic_manager_->GetAvailableServableHandles<int>();
447  EXPECT_EQ(0, wrong_type_handles.size());
448 }
449 
450 TEST_P(BasicManagerTest, GetManagedServableNames) {
451  EXPECT_THAT(basic_manager_->GetManagedServableNames(),
452  UnorderedElementsAre(kServableName, kServableName2));
453 }
454 
455 TEST_P(BasicManagerTest,
456  GetManagedServableStateSnapshotWithoutAdditionalState) {
457  const std::vector<ServableStateSnapshot<>> expected = {
458  {{kServableName, 1}, LoaderHarness::State::kReady, {}},
459  {{kServableName, 2}, LoaderHarness::State::kReady, {}}};
460  EXPECT_THAT(basic_manager_->GetManagedServableStateSnapshots(kServableName),
461  UnorderedElementsAreArray(expected));
462 }
463 
464 TEST_P(BasicManagerTest, GetManagedServableStateSnapshot) {
465  // Check servable state snapshot corresponding to a servable-id that is in
466  // ready state.
467  const ServableId id_ready = {kServableName, 1};
468  const absl::optional<ServableStateSnapshot<>> actual_ready_snapshot =
469  basic_manager_->GetManagedServableStateSnapshot(id_ready);
470  EXPECT_TRUE(actual_ready_snapshot);
471  const ServableStateSnapshot<> expected_ready_snapshot = {
472  id_ready, LoaderHarness::State::kReady, {}};
473  EXPECT_EQ(actual_ready_snapshot, expected_ready_snapshot);
474 
475  // Check servable state snapshot corresponding to a servable-id that is not
476  // managed by the basic-manager.
477  const ServableId id_notmanaged = {kServableName, 8};
478  EXPECT_FALSE(basic_manager_->GetManagedServableStateSnapshot(id_notmanaged));
479 }
480 
481 TEST_P(BasicManagerTest, GetManagedServableStateSnapshotsWithAdditionalState) {
482  TF_CHECK_OK(basic_manager_->ManageServableWithAdditionalState(
483  CreateServable({kServableName3, 0}), std::unique_ptr<int>(new int(0))));
484  TF_CHECK_OK(basic_manager_->ManageServableWithAdditionalState(
485  CreateServable({kServableName3, 1}), std::unique_ptr<int>(new int(1))));
486  const std::vector<ServableStateSnapshot<int>> expected = {
487  {{kServableName3, 0}, LoaderHarness::State::kNew, {0}},
488  {{kServableName3, 1}, LoaderHarness::State::kNew, {1}}};
489  EXPECT_THAT(
490  basic_manager_->GetManagedServableStateSnapshots<int>(kServableName3),
491  UnorderedElementsAreArray(expected));
492 }
493 
494 TEST_P(BasicManagerTest, MultipleManageCallsUsesFirstServable) {
495  const ServableId id = {kServableName, 1};
496 
497  // Servable 'id' is already managed, so further ManageServable() calls should
498  // fail (and not affect the status of the already-managed servable).
499  std::unique_ptr<Loader> first_ignored_loader(
500  new FakeLoader(1, errors::Internal("An error.")));
501  EXPECT_FALSE(basic_manager_
502  ->ManageServable(
503  CreateServableData(id, std::move(first_ignored_loader)))
504  .ok());
505 
506  // Same thing, but this time using a loader for a different servable version.
507  std::unique_ptr<Loader> second_ignored_loader(
508  new FakeLoader(2, errors::Internal("An error.")));
509  EXPECT_FALSE(basic_manager_
510  ->ManageServable(
511  CreateServableData(id, std::move(second_ignored_loader)))
512  .ok());
513 
514  ServableHandle<int64_t> handle;
515  TF_ASSERT_OK(basic_manager_->GetServableHandle(
516  ServableRequest::Specific(kServableName, 1), &handle));
517  EXPECT_EQ(1, *handle);
518 }
519 
520 // Tests to ensure the manager doesn't try to load or serve an incoming
521 // erroneous servable.
522 TEST_P(BasicManagerTest, ErroneousServable) {
523  const ServableId id = {kServableName, 3};
524  TF_ASSERT_OK(basic_manager_->ManageServable(
525  ServableData<std::unique_ptr<Loader>>(id, errors::Unknown("error"))));
526 
527  ServableHandle<int64_t> handle;
528  Status status = basic_manager_->GetServableHandle(
529  ServableRequest::Specific(kServableName, 3), &handle);
530  EXPECT_FALSE(status.ok()) << status;
531  basic_manager_->LoadServable(
532  id, [](const Status& status) { EXPECT_FALSE(status.ok()) << status; });
533 
534  status = basic_manager_->GetServableHandle(
535  ServableRequest::Specific(kServableName, 3), &handle);
536  EXPECT_FALSE(status.ok()) << status;
537 }
538 
539 // Tests to ensure that the deletion of a loader/servable occurs in a manager
540 // thread, and not a request thread.
541 TEST_P(BasicManagerTest, DestructOnNonServingThread) {
542  const ServableId id = {kServableName, 7};
543  TF_ASSERT_OK(basic_manager_->ManageServable(
544  CreateServableData(id, std::unique_ptr<Loader>(new FakeLoader(7)))));
545  basic_manager_->LoadServable(
546  id, [](const Status& status) { TF_ASSERT_OK(status); });
547  WaitUntilServableManagerStateIsOneOf(
548  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
549 
550  std::unique_ptr<ServableHandle<int64_t>> latest_handle(
551  new ServableHandle<int64_t>());
552  const Status status = basic_manager_->GetServableHandle(
553  ServableRequest::Latest(kServableName), latest_handle.get());
554  TF_ASSERT_OK(status);
555  EXPECT_EQ(7, **latest_handle);
556 
557  Notification done_unload_servable;
558  std::unique_ptr<Thread> unload_servable(
559  Env::Default()->StartThread({}, "UnloadServable", [&]() {
560  // Unload the servable.
561  basic_manager_->UnloadServable(
562  id, [](const Status& status) { TF_ASSERT_OK(status); });
563  WaitUntilServableManagerStateIsOneOf(
564  servable_state_monitor_, id, {ServableState::ManagerState::kEnd});
565  // TODO(b/35997855): Don't just ignore this status!
566  TF_ASSERT_OK(basic_manager_->StopManagingServable(id));
567  // The servable has been deleted in this thread if there is no
568  // thread-pool for load/unload.
569  if (thread_pool_sizes_.num_load_threads == 0) {
570  EXPECT_TRUE(FakeLoader::was_deleted_in_this_thread());
571  }
572  done_unload_servable.Notify();
573  }));
574 
575  // This will unblock the UnloadServable.
576  latest_handle.reset();
577  done_unload_servable.WaitForNotification();
578  // The servable wasn't deleted in this thread.
579  ASSERT_FALSE(FakeLoader::was_deleted_in_this_thread());
580 }
581 
582 TEST_P(BasicManagerTest, AdditionalState) {
583  const ServableId id = {kServableName, 3};
584  std::unique_ptr<int> state(new int(1));
585  TF_CHECK_OK(basic_manager_->ManageServableWithAdditionalState(
586  CreateServable(id), std::move(state)));
587 
588  EXPECT_EQ(1, *basic_manager_->GetAdditionalServableState<int>(id));
589  EXPECT_EQ(nullptr, basic_manager_->GetAdditionalServableState<float>(id));
590 }
591 
592 TEST_P(BasicManagerTest, NoAdditionalState) {
593  const ServableId id = {kServableName, 3};
594  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
595 
596  // Will return nullptr when there is no metadata set.
597  EXPECT_EQ(nullptr, basic_manager_->GetAdditionalServableState<int>(id));
598  EXPECT_EQ(nullptr, basic_manager_->GetAdditionalServableState<float>(id));
599 }
600 
601 TEST_P(BasicManagerTest, OutOfOrderLoadServable) {
602  const ServableId id = {kServableName, 3};
603  basic_manager_->LoadServable(id, [](const Status& status) {
604  EXPECT_FALSE(status.ok());
605  EXPECT_EQ(error::NOT_FOUND, status.code());
606  EXPECT_THAT(status.message(), HasSubstr("is not being managed"));
607  });
608 }
609 
610 TEST_P(BasicManagerTest, MultipleLoadServables) {
611  const ServableId id = {kServableName, 3};
612  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
613  basic_manager_->LoadServable(
614  id, [](const Status& status) { TF_ASSERT_OK(status); });
615  WaitUntilServableManagerStateIsOneOf(
616  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
617  basic_manager_->LoadServable(id, [](const Status& status) {
618  EXPECT_FALSE(status.ok());
619  EXPECT_EQ(error::FAILED_PRECONDITION, status.code());
620  EXPECT_THAT(status.message(), HasSubstr("Duplicate load request"));
621  });
622 }
623 
624 TEST_P(BasicManagerTest, MultipleUnloadServables) {
625  const ServableId id = {kServableName, 3};
626  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
627  basic_manager_->LoadServable(
628  id, [](const Status& status) { TF_ASSERT_OK(status); });
629  WaitUntilServableManagerStateIsOneOf(
630  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
631  basic_manager_->UnloadServable(
632  id, [](const Status& status) { TF_ASSERT_OK(status); });
633  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
634  {ServableState::ManagerState::kEnd});
635  basic_manager_->UnloadServable(id, [](const Status& status) {
636  EXPECT_FALSE(status.ok());
637  EXPECT_EQ(error::FAILED_PRECONDITION, status.code());
638  EXPECT_THAT(status.message(),
639  HasSubstr("unload already requested/ongoing"));
640  });
641 }
642 
643 TEST_P(BasicManagerTest, UnloadWithoutManage) {
644  const ServableId id = {kServableName, 3};
645  basic_manager_->UnloadServable(id, [](const Status& status) {
646  EXPECT_FALSE(status.ok());
647  EXPECT_EQ(error::NOT_FOUND, status.code());
648  EXPECT_THAT(status.message(), HasSubstr("is not being managed"));
649  });
650 }
651 
652 TEST_P(BasicManagerTest, UnloadWithoutLoad) {
653  const ServableId id = {kServableName, 3};
654  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
655  basic_manager_->UnloadServable(id, [](const Status& status) {
656  EXPECT_FALSE(status.ok());
657  EXPECT_EQ(error::FAILED_PRECONDITION, status.code());
658  EXPECT_THAT(status.message(), HasSubstr("Servable not loaded"));
659  });
660 }
661 
662 TEST_P(BasicManagerTest, EventBusErroneousVersion) {
663  const ServableId id = {kServableName, 3};
664  TF_ASSERT_OK(basic_manager_->ManageServable(
665  ServableData<std::unique_ptr<Loader>>(id, errors::Unknown("error"))));
666 
667  const ServableState expected_published_state = {
668  id, ServableState::ManagerState::kEnd, errors::Unknown("error")};
669  EXPECT_THAT(*servable_state_monitor_.GetState(id),
670  EqualsServableState(expected_published_state));
671 }
672 
673 TEST_P(BasicManagerTest, EventBusErrorOnLoad) {
674  const ServableId id = {kServableName, 7};
675  std::unique_ptr<Loader> loader(
676  new FakeLoader(7, errors::Internal("Error on load.")));
677  TF_ASSERT_OK(basic_manager_->ManageServable({id, std::move(loader)}));
678 
679  const ServableState start_state = {id, ServableState::ManagerState::kStart,
680  OkStatus()};
681  EXPECT_THAT(*servable_state_monitor_.GetState(id),
682  EqualsServableState(start_state));
683 
684  basic_manager_->LoadServable(id, [](const Status& status) {});
685  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
686  {ServableState::ManagerState::kEnd});
687 
688  const ServableState error_state = {id, ServableState::ManagerState::kEnd,
689  errors::Internal("Error on load.")};
690  EXPECT_THAT(*servable_state_monitor_.GetState(id),
691  EqualsServableState(error_state));
692 }
693 
694 TEST_P(BasicManagerTest, EventBusServableLifecycle) {
695  const ServableId id = {kServableName, 7};
696  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>();
697  TF_ASSERT_OK(
698  basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
699 
700  const ServableState start_state = {id, ServableState::ManagerState::kStart,
701  OkStatus()};
702  EXPECT_THAT(*servable_state_monitor_.GetState(id),
703  EqualsServableState(start_state));
704 
705  Notification load_called;
706  Notification load_continue;
707  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{id}))
708  .WillOnce(InvokeWithoutArgs([&]() {
709  load_called.Notify();
710  load_continue.WaitForNotification();
711  return OkStatus();
712  }));
713 
714  std::unique_ptr<Thread> load_thread(
715  Env::Default()->StartThread(ThreadOptions(), "LoadThread", [&]() {
716  basic_manager_->LoadServable(id, [](const Status& status) {});
717  }));
718 
719  load_called.WaitForNotification();
720 
721  const ServableState loading_state = {
722  id, ServableState::ManagerState::kLoading, OkStatus()};
723  EXPECT_THAT(*servable_state_monitor_.GetState(id),
724  EqualsServableState(loading_state));
725 
726  load_continue.Notify();
727  WaitUntilServableManagerStateIsOneOf(
728  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
729 
730  const ServableState available_state = {
731  id, ServableState::ManagerState::kAvailable, OkStatus()};
732  EXPECT_THAT(*servable_state_monitor_.GetState(id),
733  EqualsServableState(available_state));
734 
735  Notification unload_called;
736  Notification unload_continue;
737  EXPECT_CALL(*loader, Unload()).WillOnce(Invoke([&]() {
738  unload_called.Notify();
739  unload_continue.WaitForNotification();
740  }));
741  // Scoped to ensure UnloadServable() is scheduled.
742  std::unique_ptr<Thread> unload_thread(
743  Env::Default()->StartThread(ThreadOptions(), "UnloadThread", [&]() {
744  basic_manager_->UnloadServable(id, [](const Status& status) {});
745  }));
746 
747  unload_called.WaitForNotification();
748 
749  const ServableState unloading_state = {
750  id, ServableState::ManagerState::kUnloading, OkStatus()};
751  EXPECT_THAT(*servable_state_monitor_.GetState(id),
752  EqualsServableState(unloading_state));
753 
754  unload_continue.Notify();
755  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
756  {ServableState::ManagerState::kEnd});
757 
758  const ServableState end_state = {id, ServableState::ManagerState::kEnd,
759  OkStatus()};
760  EXPECT_THAT(*servable_state_monitor_.GetState(id),
761  EqualsServableState(end_state));
762 }
763 
764 // Tests whether there are any errors if we don't have an event bus configured.
765 TEST_P(BasicManagerTest, NoEventBus) {
766  BasicManager::Options options;
767  // Single threaded execution.
768  options.num_load_threads = 0;
769  // No event bus.
770  options.servable_event_bus = nullptr;
771  std::unique_ptr<BasicManager> manager;
772  TF_ASSERT_OK(BasicManager::Create(std::move(options), &manager));
773 
774  const ServableId id = {kServableName, 7};
775  std::unique_ptr<Loader> loader(new FakeLoader(7));
776  TF_ASSERT_OK(manager->ManageServable({id, std::move(loader)}));
777  manager->LoadServable(id, [](const Status& status) { TF_ASSERT_OK(status); });
778  manager->UnloadServable(id,
779  [](const Status& status) { TF_ASSERT_OK(status); });
780 }
781 
782 TEST_P(BasicManagerTest, LoadsThenUnloads) {
783  std::set<ServableId> servables;
784  // Scoped so that all loads can be scheduled before proceeding.
785  {
786  ThreadPoolExecutor load_executor(Env::Default(), "LoadServables",
787  kNumThreads);
788  for (int i = 0; i < 20; ++i) {
789  const ServableId id = {kServableName3, i};
790  servables.insert(id);
791  load_executor.Schedule([this, id]() {
792  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
793  basic_manager_->LoadServable(
794  id, [](const Status& status) { TF_ASSERT_OK(status); });
795  });
796  }
797  }
798 
799  // At this point, all loads may not have completed, so we wait for them.
800  for (const ServableId& servable : servables) {
801  WaitUntilServableManagerStateIsOneOf(
802  servable_state_monitor_, servable,
803  {ServableState::ManagerState::kAvailable});
804  }
805 
806  {
807  ThreadPoolExecutor unload_executor(Env::Default(), "UnloadServables",
808  kNumThreads);
809  // Doing in reverse.
810  for (int i = 19; i >= 0; --i) {
811  unload_executor.Schedule([this, i]() {
812  const ServableId id = {kServableName3, i};
813  basic_manager_->UnloadServable(
814  id, [](const Status& status) { TF_ASSERT_OK(status); });
815  });
816  }
817  }
818 }
819 
820 TEST_P(BasicManagerTest, InterleavedLoadsAndUnloads) {
821  ThreadPoolExecutor executor(Env::Default(), "InterleavedLoadsAndUnloads",
822  kNumThreads);
823  for (int i = 0; i < 20; ++i) {
824  executor.Schedule([this, i]() {
825  const ServableId id = {kServableName3, i};
826  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
827  Notification load_done;
828  basic_manager_->LoadServable(id, [&load_done](const Status& status) {
829  TF_ASSERT_OK(status);
830  load_done.Notify();
831  });
832  load_done.WaitForNotification();
833  basic_manager_->UnloadServable(
834  id, [](const Status& status) { TF_ASSERT_OK(status); });
835  });
836  }
837 }
838 
839 class SetNumLoadThreadsBasicManagerTest : public ::testing::Test {
840  protected:
841  SetNumLoadThreadsBasicManagerTest() {
842  BasicManager::Options options;
843  options.num_load_threads = 0;
844  options.max_num_load_retries = 10;
845  options.load_retry_interval_micros = 0;
846  TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager_));
847  }
848 
849  std::unique_ptr<BasicManager> basic_manager_;
850 };
851 
852 TEST_F(SetNumLoadThreadsBasicManagerTest, ThreadPoolSwapped) {
853  test_util::BasicManagerTestAccess manager_test_access(basic_manager_.get());
854  manager_test_access.SetNumLoadThreads(2);
855  EXPECT_EQ(2, manager_test_access.num_load_threads());
856 
857  const auto load_done_fn = [&](const Status& status) {
858  TF_ASSERT_OK(status);
859  // Tests whether the threadpools are actually swapped in
860  // SetNumLoadThreads().
861  static thread_local int per_thread_load_ctr = 0;
862  ++per_thread_load_ctr;
863  EXPECT_EQ(1, per_thread_load_ctr);
864  };
865 
866  const ServableId id0 = {kServableName3, 0};
867  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id0)));
868  basic_manager_->LoadServable(id0, load_done_fn);
869 
870  manager_test_access.SetNumLoadThreads(0);
871  EXPECT_EQ(0, manager_test_access.num_load_threads());
872 
873  const ServableId id1 = {kServableName3, 1};
874  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id1)));
875  basic_manager_->LoadServable(id1, load_done_fn);
876 
877  // Force the manager to finish before deleting the notifications.
878  basic_manager_.reset();
879 }
880 
881 TEST_F(SetNumLoadThreadsBasicManagerTest, ThreadPoolsNotAliveSimultaneously) {
882  test_util::BasicManagerTestAccess manager_test_access(basic_manager_.get());
883  manager_test_access.SetNumLoadThreads(1);
884  EXPECT_EQ(1, manager_test_access.num_load_threads());
885 
886  std::set<string> data_race_set;
887  const auto data_race_fn = [&](const Status& status) {
888  // This line will cause a data race if both the loads happen simultaneously
889  // on different threads. This will be caught by the ThreadSanitizer, causing
890  // the test to fail.
891  data_race_set.insert("string");
892  };
893 
894  const ServableId id0 = {kServableName3, 0};
895  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id0)));
896  Notification notify_for_setting;
897  Notification continue_load;
898  basic_manager_->LoadServable(id0, [&](const Status& status) {
899  notify_for_setting.Notify();
900  continue_load.WaitForNotification();
901  data_race_fn(status);
902  });
903 
904  {
905  ThreadPoolExecutor executor(Env::Default(), "SetNumLoadThreads",
906  kNumThreads);
907  executor.Schedule([&]() {
908  notify_for_setting.WaitForNotification();
909  manager_test_access.SetNumLoadThreads(1);
910  EXPECT_EQ(1, manager_test_access.num_load_threads());
911  });
912 
913  executor.Schedule([&]() {
914  const ServableId id1 = {kServableName3, 1};
915  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id1)));
916  continue_load.Notify();
917  basic_manager_->LoadServable(
918  id1, [&](const Status& status) { data_race_fn(status); });
919  });
920  }
921 
922  // Force the manager to finish before deleting the notifications.
923  basic_manager_.reset();
924 }
925 
926 // Tests whether the fast-load scenario works. In the fast-load scenario we try
927 // to load a bunch of servables as fast as possible using a lot of threads.
928 TEST_F(SetNumLoadThreadsBasicManagerTest, FastLoad) {
929  test_util::BasicManagerTestAccess manager_test_access(basic_manager_.get());
930  const uint32 prev_num_load_threads = manager_test_access.num_load_threads();
931  manager_test_access.SetNumLoadThreads(32);
932  EXPECT_EQ(32, manager_test_access.num_load_threads());
933 
934  {
935  ThreadPoolExecutor executor(Env::Default(), "FirstThreadPoolLoads",
936  kNumThreads);
937  for (int i = 0; i < 20; ++i) {
938  executor.Schedule([this, i]() {
939  const ServableId id = {kServableName3, i};
940  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
941  basic_manager_->LoadServable(
942  id, [](const Status& status) { TF_ASSERT_OK(status); });
943  // We don't wait for load to be done here because we want to test that
944  // SetNumLoadThreads() waits properly till all queued loads are
945  // finished. If a queued load hasn't been finished the corresponding
946  // UnloadServable() will fail.
947  });
948  }
949  }
950 
951  manager_test_access.SetNumLoadThreads(prev_num_load_threads);
952  EXPECT_EQ(prev_num_load_threads, manager_test_access.num_load_threads());
953 
954  {
955  ThreadPoolExecutor executor(Env::Default(), "Unloads", kNumThreads);
956  for (int i = 0; i < 20; ++i) {
957  executor.Schedule([this, i]() {
958  const ServableId id = {kServableName3, i};
959  basic_manager_->UnloadServable(
960  id, [](const Status& status) { TF_ASSERT_OK(status); });
961  });
962  }
963  }
964 }
965 
966 // This filesystem detects a call to FlushCaches(), which is triggered by the
967 // BasicManager's call to Env::Default()->FlushFileSystemCaches() after loading
968 // a servable.
969 class FlushDetectingFileSystem : public NullFileSystem {
970  public:
971  void FlushCaches() override { flushed = true; }
972  static std::atomic<bool> flushed;
973 };
974 
975 std::atomic<bool> FlushDetectingFileSystem::flushed;
976 
977 REGISTER_FILE_SYSTEM("flush", FlushDetectingFileSystem);
978 
979 // This test loads servables with BasicManager::Options::flush_filesystem_caches
980 // true or false, and verifies that filesystem caches were flushed (or not
981 // flushed) as expected.
982 class FlushFileSystemCachesTest : public ::testing::TestWithParam<bool> {
983  protected:
984  FlushFileSystemCachesTest() : flush_filesystem_caches_(GetParam()) {
985  BasicManager::Options options;
986  options.flush_filesystem_caches = flush_filesystem_caches_;
987  TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager_));
988  }
989 
990  std::unique_ptr<BasicManager> basic_manager_;
991  bool flush_filesystem_caches_;
992 };
993 
994 TEST_P(FlushFileSystemCachesTest, Load) {
995  test_util::BasicManagerTestAccess manager_test_access(basic_manager_.get());
996  // The number of load threads is initially zero, so filesystems should be
997  // flushed if flush_filesystem_caches_ is true.
998  FlushDetectingFileSystem::flushed.store(false);
999  const ServableId id0 = {kServableName3, 0};
1000  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id0)));
1001  basic_manager_->LoadServable(id0, [&](const Status& status) {
1002  TF_ASSERT_OK(status);
1003  EXPECT_EQ(flush_filesystem_caches_,
1004  FlushDetectingFileSystem::flushed.load());
1005  });
1006  // Load another servable with two load threads. Filesystem caches should not
1007  // be flushed.
1008  manager_test_access.SetNumLoadThreads(2);
1009  FlushDetectingFileSystem::flushed.store(false);
1010  const ServableId id1 = {kServableName3, 1};
1011  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id1)));
1012  basic_manager_->LoadServable(id1, [&](const Status& status) {
1013  TF_ASSERT_OK(status);
1014  EXPECT_FALSE(FlushDetectingFileSystem::flushed.load());
1015  });
1016  // Now move to a single load thread and load a third servable. Filesystem
1017  // caches should once again be flushed if flush_filesystem_caches_ is true.
1018  manager_test_access.SetNumLoadThreads(1);
1019  FlushDetectingFileSystem::flushed.store(false);
1020  const ServableId id2 = {kServableName3, 2};
1021  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id2)));
1022  basic_manager_->LoadServable(id2, [&](const Status& status) {
1023  TF_ASSERT_OK(status);
1024  EXPECT_EQ(flush_filesystem_caches_,
1025  FlushDetectingFileSystem::flushed.load());
1026  });
1027  basic_manager_.reset();
1028 }
1029 
1030 INSTANTIATE_TEST_CASE_P(WithOrWithoutFlush, FlushFileSystemCachesTest,
1031  ::testing::Bool());
1032 
1033 TEST_P(BasicManagerTest, ConcurrentLoadsOnlyOneSucceeds) {
1034  const ServableId id = {kServableName3, 0};
1035  mutex status_mu;
1036  std::vector<Status> statuses(4);
1037  {
1038  ThreadPoolExecutor load_executor(Env::Default(), "LoadServables",
1039  kNumThreads);
1040  for (int i = 0; i < 4; ++i) {
1041  load_executor.Schedule([this, id, i, &statuses, &status_mu]() {
1042  // (Suppress a possible "this servable is already managed" error.)
1043  basic_manager_->ManageServable(CreateServable(id)).IgnoreError();
1044  basic_manager_->LoadServable(
1045  id, [i, &statuses, &status_mu](const Status& status) {
1046  mutex_lock l(status_mu);
1047  statuses[i] = status;
1048  });
1049  });
1050  }
1051  }
1052 
1053  // At this point, all loads may not have completed. Deleting BasicManager
1054  // would wait for all the scheduled loads to complete before deleting it.
1055  basic_manager_.reset();
1056 
1057  int num_status_ok = 0;
1058  for (int i = 0; i < 4; ++i) {
1059  mutex_lock l(status_mu);
1060  if (!statuses[i].ok()) {
1061  EXPECT_EQ(error::FAILED_PRECONDITION, statuses[i].code());
1062  EXPECT_THAT(statuses[i].message(), HasSubstr("Duplicate load request"));
1063  } else {
1064  ++num_status_ok;
1065  }
1066  }
1067  EXPECT_EQ(1, num_status_ok);
1068 }
1069 
1070 TEST_P(BasicManagerTest, ConcurrentUnloadsOnlyOneSucceeds) {
1071  const ServableId id = {kServableName3, 0};
1072  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id)));
1073  basic_manager_->LoadServable(
1074  id, [](const Status& status) { TF_ASSERT_OK(status); });
1075  // At this point, all loads may not have completed, so we wait for them.
1076  WaitUntilServableManagerStateIsOneOf(
1077  servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
1078 
1079  mutex status_mu;
1080  std::vector<Status> statuses(4);
1081  {
1082  ThreadPoolExecutor load_executor(Env::Default(), "LoadServables",
1083  kNumThreads);
1084  for (int i = 0; i < 4; ++i) {
1085  load_executor.Schedule([this, id, i, &statuses, &status_mu]() {
1086  basic_manager_->UnloadServable(
1087  id, [i, &statuses, &status_mu](const Status& status) {
1088  mutex_lock l(status_mu);
1089  statuses[i] = status;
1090  });
1091  });
1092  }
1093  }
1094 
1095  // At this point, all unloads may not have completed. Deleting BasicManager
1096  // would wait for all the scheduled unloads to complete before deleting it.
1097  basic_manager_.reset();
1098 
1099  int num_status_ok = 0;
1100  for (int i = 0; i < 4; ++i) {
1101  mutex_lock l(status_mu);
1102  // The error can be either of 2.
1103  if (!statuses[i].ok()) {
1104  ASSERT_THAT(statuses[i].code(),
1105  AnyOf(error::NOT_FOUND, error::FAILED_PRECONDITION));
1106  if (statuses[i].code() == error::NOT_FOUND) {
1107  EXPECT_THAT(statuses[i].message(), HasSubstr("not being managed"));
1108  } else {
1109  EXPECT_THAT(statuses[i].message(),
1110  HasSubstr("unload already requested/ongoing"));
1111  }
1112  } else {
1113  ++num_status_ok;
1114  }
1115  }
1116  EXPECT_EQ(1, num_status_ok);
1117 }
1118 
1119 TEST_P(BasicManagerTest, RetryOnLoadErrorFinallySucceeds) {
1120  const ServableId id = {kServableName, 7};
1121  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>();
1122  TF_ASSERT_OK(
1123  basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
1124  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{id}))
1125  .WillOnce(Return(errors::Internal("Load error.")))
1126  .WillRepeatedly(Return(OkStatus()));
1127  basic_manager_->LoadServable(
1128  id, [](const Status& status) { TF_ASSERT_OK(status); });
1129 }
1130 
1131 TEST_P(BasicManagerTest, RetryOnLoadErrorFinallyFails) {
1132  const ServableId id = {kServableName, 7};
1133  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>();
1134  TF_ASSERT_OK(
1135  basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
1136  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{id}))
1137  .WillRepeatedly(Return(errors::Internal("Load error.")));
1138  basic_manager_->LoadServable(id, [](const Status& status) {
1139  EXPECT_EQ(errors::Internal("Load error."), status);
1140  });
1141 }
1142 
1143 // Tests cancelling load retries.
1144 TEST_P(BasicManagerTest, RetryOnLoadErrorCancelledLoad) {
1145  const ServableId id = {kServableName, 7};
1146  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>();
1147  TF_ASSERT_OK(
1148  basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
1149 
1150  Notification load_called;
1151  Notification load_should_return;
1152  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{id}))
1153  .WillOnce(InvokeWithoutArgs([&load_called, &load_should_return]() {
1154  load_called.Notify();
1155  load_should_return.WaitForNotification();
1156  return errors::Internal("Load error.");
1157  }))
1158  .WillRepeatedly(Return(OkStatus()));
1159  std::unique_ptr<Thread> load_thread(
1160  Env::Default()->StartThread(ThreadOptions(), "LoadServable", [&]() {
1161  basic_manager_->LoadServable(id, [](const Status& status) {
1162  EXPECT_EQ(errors::Internal("Load error."), status);
1163  });
1164  }));
1165  load_called.WaitForNotification();
1166  basic_manager_->CancelLoadServableRetry(id);
1167  load_should_return.Notify();
1168  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
1169  {ServableState::ManagerState::kEnd});
1170 }
1171 
1172 TEST_P(BasicManagerTest, LoadAfterCancelledLoad) {
1173  const ServableId id = {kServableName, 7};
1174  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>();
1175  TF_ASSERT_OK(
1176  basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
1177 
1178  Notification load_called;
1179  Notification load_should_return;
1180  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{id}))
1181  .WillOnce(InvokeWithoutArgs([&load_called, &load_should_return]() {
1182  load_called.Notify();
1183  load_should_return.WaitForNotification();
1184  return errors::Internal("Load error.");
1185  }))
1186  .WillRepeatedly(Return(OkStatus()));
1187 
1188  std::unique_ptr<Thread> load_thread(
1189  Env::Default()->StartThread(ThreadOptions(), "LoadServable", [&]() {
1190  basic_manager_->LoadServable(id, [](const Status& status) {
1191  EXPECT_EQ(errors::Internal("Load error."), status);
1192  });
1193  }));
1194  load_called.WaitForNotification();
1195  basic_manager_->CancelLoadServableRetry(id);
1196  load_should_return.Notify();
1197  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
1198  {ServableState::ManagerState::kEnd});
1199 
1200  basic_manager_->LoadServable(
1201  id, [](const Status& status) { EXPECT_FALSE(status.ok()) << status; });
1202 }
1203 
1204 TEST(NonParameterizedBasicManagerTest, PreLoadHook) {
1205  BasicManager::Options options;
1206  // Single threaded execution.
1207  options.num_load_threads = 0;
1208  // No event bus.
1209  options.servable_event_bus = nullptr;
1210  MockFunction<void(const ServableId&)> mock_pre_load_hook;
1211  options.pre_load_hook = mock_pre_load_hook.AsStdFunction();
1212  std::unique_ptr<BasicManager> manager;
1213  TF_ASSERT_OK(BasicManager::Create(std::move(options), &manager));
1214 
1215  const ServableId id = {kServableName, 7};
1216  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>();
1217  TF_ASSERT_OK(manager->ManageServable({id, std::unique_ptr<Loader>(loader)}));
1218 
1219  bool pre_load_hook_called = false;
1220  EXPECT_CALL(mock_pre_load_hook, Call(id)).WillOnce(InvokeWithoutArgs([&]() {
1221  pre_load_hook_called = true;
1222  }));
1223  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{id}))
1224  .WillOnce(InvokeWithoutArgs([&]() {
1225  EXPECT_TRUE(pre_load_hook_called);
1226  return OkStatus();
1227  }));
1228  manager->LoadServable(id, [](const Status& status) { TF_ASSERT_OK(status); });
1229  manager->UnloadServable(id,
1230  [](const Status& status) { TF_ASSERT_OK(status); });
1231 }
1232 
1233 // Creates a ResourceAllocation proto with 'quantity' units of RAM.
1234 ResourceAllocation CreateResourceQuantity(const int quantity) {
1235  ResourceAllocation allocation;
1236  auto* ram_resource = allocation.add_resource_quantities();
1237  ram_resource->mutable_resource()->set_device("main");
1238  ram_resource->mutable_resource()->set_kind("ram");
1239  ram_resource->set_quantity(quantity);
1240  return allocation;
1241 }
1242 
1243 // Creates a resource tracker that deals with just a single resource (RAM) and
1244 // initially has 'total_ram_resources' quantity of that resource.
1245 std::unique_ptr<ResourceTracker> CreateSimpleResourceTracker(
1246  const int resource_quantity) {
1247  std::unique_ptr<ResourceUtil> util(new ResourceUtil({{{"main", 1}}}));
1248  std::unique_ptr<ResourceTracker> tracker;
1249  TF_CHECK_OK(ResourceTracker::Create(CreateResourceQuantity(resource_quantity),
1250  std::move(util), &tracker));
1251  return tracker;
1252 }
1253 
1254 class ResourceConstrainedBasicManagerTest : public ::testing::Test {
1255  protected:
1256  ResourceConstrainedBasicManagerTest()
1257  : servable_event_bus_(EventBus<ServableState>::CreateEventBus()),
1258  servable_state_monitor_(servable_event_bus_.get()) {
1259  BasicManager::Options options;
1260  // Seed the manager with ten resource units.
1261  options.resource_tracker = CreateSimpleResourceTracker(10);
1262  options.servable_event_bus = servable_event_bus_.get();
1263  // Allow up to two loads and two unloads to be processed concurrently.
1264  options.num_load_threads = 2;
1265  options.num_unload_threads = 2;
1266  // We don't want retries.
1267  options.max_num_load_retries = 0;
1268  TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager_));
1269  }
1270 
1271  std::shared_ptr<EventBus<ServableState>> servable_event_bus_;
1272  ServableStateMonitor servable_state_monitor_;
1273  std::unique_ptr<BasicManager> basic_manager_;
1274 };
1275 
1276 // A loader whose Load() method calls into a blocking counter. It requires 5
1277 // resource units, i.e. half of the total system resources.
1278 class BarrierLoader : public Loader {
1279  public:
1280  explicit BarrierLoader(BlockingCounter* counter) : counter_(counter) {}
1281  ~BarrierLoader() override = default;
1282 
1283  Status EstimateResources(ResourceAllocation* estimate) const override {
1284  *estimate = CreateResourceQuantity(5);
1285  return OkStatus();
1286  }
1287 
1288  Status Load() override {
1289  counter_->DecrementCount();
1290  counter_->Wait();
1291  return OkStatus();
1292  }
1293 
1294  void Unload() override {}
1295 
1296  AnyPtr servable() override { return AnyPtr(); }
1297 
1298  private:
1299  BlockingCounter* const counter_;
1300 
1301  TF_DISALLOW_COPY_AND_ASSIGN(BarrierLoader);
1302 };
1303 
1304 TEST_F(ResourceConstrainedBasicManagerTest, ConcurrentLoads) {
1305  // Two loads that each require half the system resources should be handled
1306  // concurrently (i.e. the manager should not serialize them needlessly).
1307  // BarrierLoader verifies that the Load() calls occur concurrently.
1308  int kNumLoaders = 2;
1309  BlockingCounter barrier(kNumLoaders);
1310  for (int i = 0; i < kNumLoaders; ++i) {
1311  std::unique_ptr<Loader> loader(new BarrierLoader(&barrier));
1312  const ServableId id = {"barrier", i};
1313  TF_ASSERT_OK(basic_manager_->ManageServable(
1314  CreateServableData(id, std::move(loader))));
1315  basic_manager_->LoadServable(
1316  id, [](const Status& status) { TF_EXPECT_OK(status); });
1317  }
1318  // Force the manager to finish before deleting 'barrier'.
1319  basic_manager_.reset();
1320 }
1321 
1322 TEST_F(ResourceConstrainedBasicManagerTest, InsufficientResources) {
1323  // A first loader that succeeds and consumes all of the serving system's
1324  // resources.
1325  const ServableId hogging_id = {"hogging", 0};
1326  test_util::MockLoader* hogging_loader = new NiceMock<test_util::MockLoader>;
1327  ON_CALL(*hogging_loader, EstimateResources(_))
1328  .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1329  *estimate = CreateResourceQuantity(10 /* = total system resources */);
1330  return OkStatus();
1331  }));
1332  EXPECT_CALL(*hogging_loader, LoadWithMetadata(Loader::Metadata{hogging_id}))
1333  .WillOnce(Return(OkStatus()));
1334  TF_ASSERT_OK(basic_manager_->ManageServable(
1335  CreateServableData(hogging_id, std::unique_ptr<Loader>(hogging_loader))));
1336  Notification hogging_loaded;
1337  basic_manager_->LoadServable(hogging_id,
1338  [&hogging_loaded](const Status& status) {
1339  TF_EXPECT_OK(status);
1340  hogging_loaded.Notify();
1341  });
1342  hogging_loaded.WaitForNotification();
1343 
1344  // A second loader that gets rejected due to insufficient resources.
1345  const ServableId rejected_id = {"rejected", 0};
1346  test_util::MockLoader* rejected_loader = new NiceMock<test_util::MockLoader>;
1347  ON_CALL(*rejected_loader, EstimateResources(_))
1348  .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1349  *estimate = CreateResourceQuantity(1);
1350  return OkStatus();
1351  }));
1352  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1353  rejected_id, std::unique_ptr<Loader>(rejected_loader))));
1354  Notification rejection_received;
1355  Status rejected_status;
1356  basic_manager_->LoadServable(
1357  rejected_id,
1358  [&rejection_received, &rejected_status](const Status& status) {
1359  ASSERT_FALSE(status.ok());
1360  ASSERT_EQ(error::RESOURCE_EXHAUSTED, status.code());
1361  rejected_status = status;
1362  rejection_received.Notify();
1363  });
1364  rejection_received.WaitForNotification();
1365  const ServableState expected_error_state = {
1366  rejected_id, ServableState::ManagerState::kEnd, rejected_status};
1367  EXPECT_THAT(*servable_state_monitor_.GetState(rejected_id),
1368  EqualsServableState(expected_error_state));
1369 
1370  // Make sure we're still managing the rejected servable.
1371  const absl::optional<ServableStateSnapshot<>> snapshot =
1372  basic_manager_->GetManagedServableStateSnapshot(rejected_id);
1373  EXPECT_EQ(LoaderHarness::State::kError, snapshot->state);
1374 }
1375 
1376 TEST_F(ResourceConstrainedBasicManagerTest, ResourcesReleasedIfLoadFails) {
1377  // A first loader that fails. Its resource reservation should get released.
1378  const ServableId failing_id = {"failing", 0};
1379  test_util::MockLoader* failing_loader = new NiceMock<test_util::MockLoader>;
1380  ON_CALL(*failing_loader, EstimateResources(_))
1381  .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1382  *estimate = CreateResourceQuantity(10);
1383  return OkStatus();
1384  }));
1385  EXPECT_CALL(*failing_loader, LoadWithMetadata(Loader::Metadata{failing_id}))
1386  .WillOnce(Return(errors::Unknown("Load failure")));
1387  TF_ASSERT_OK(basic_manager_->ManageServable(
1388  CreateServableData(failing_id, std::unique_ptr<Loader>(failing_loader))));
1389  Notification failing_failed;
1390  basic_manager_->LoadServable(failing_id,
1391  [&failing_failed](const Status& status) {
1392  EXPECT_FALSE(status.ok());
1393  failing_failed.Notify();
1394  });
1395  failing_failed.WaitForNotification();
1396 
1397  // A second loader that succeeds. The failure of the first loader should
1398  // enable this one to get loaded (versus rejection with a resource exhaustion
1399  // error).
1400  const ServableId succeeding_id = {"succeeding", 0};
1401  test_util::MockLoader* succeeding_loader =
1402  new NiceMock<test_util::MockLoader>;
1403  ON_CALL(*succeeding_loader, EstimateResources(_))
1404  .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1405  *estimate = CreateResourceQuantity(10);
1406  return OkStatus();
1407  }));
1408  EXPECT_CALL(*succeeding_loader,
1409  LoadWithMetadata(Loader::Metadata{succeeding_id}))
1410  .WillOnce(Return(OkStatus()));
1411  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1412  succeeding_id, std::unique_ptr<Loader>(succeeding_loader))));
1413  basic_manager_->LoadServable(
1414  succeeding_id, [](const Status& status) { TF_EXPECT_OK(status); });
1415 }
1416 
1417 TEST_F(ResourceConstrainedBasicManagerTest,
1418  ResourcesReleasedIfEstimateDecreasesAfterLoad) {
1419  // A first loader that succeeds and then lowers its resource estimate.
1420  const ServableId overestimating_id = {"overestimating", 0};
1421  test_util::MockLoader* overestimating_loader =
1422  new NiceMock<test_util::MockLoader>;
1423  {
1424  InSequence sequence;
1425  EXPECT_CALL(*overestimating_loader, EstimateResources(_))
1426  .WillOnce(Invoke([](ResourceAllocation* estimate) {
1427  *estimate = CreateResourceQuantity(10);
1428  return OkStatus();
1429  }))
1430  .RetiresOnSaturation();
1431  EXPECT_CALL(*overestimating_loader,
1432  LoadWithMetadata(Loader::Metadata{overestimating_id}))
1433  .WillOnce(Return(OkStatus()));
1434  EXPECT_CALL(*overestimating_loader, EstimateResources(_))
1435  .WillOnce(Invoke([](ResourceAllocation* estimate) {
1436  *estimate = CreateResourceQuantity(5 /* lower estimate after load */);
1437  return OkStatus();
1438  }))
1439  .RetiresOnSaturation();
1440  }
1441  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1442  overestimating_id, std::unique_ptr<Loader>(overestimating_loader))));
1443  Notification overestimating_loaded;
1444  basic_manager_->LoadServable(overestimating_id,
1445  [&overestimating_loaded](const Status& status) {
1446  TF_EXPECT_OK(status);
1447  overestimating_loaded.Notify();
1448  });
1449  overestimating_loaded.WaitForNotification();
1450 
1451  // A second loader that succeeds. The re-estimation of the first loader should
1452  // enable this one to get loaded (versus rejection with a resource exhaustion
1453  // error).
1454  const ServableId succeeding_id = {"succeeding", 0};
1455  test_util::MockLoader* succeeding_loader =
1456  new NiceMock<test_util::MockLoader>;
1457  ON_CALL(*succeeding_loader, EstimateResources(_))
1458  .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1459  *estimate = CreateResourceQuantity(5);
1460  return OkStatus();
1461  }));
1462  EXPECT_CALL(*succeeding_loader,
1463  LoadWithMetadata(Loader::Metadata{succeeding_id}))
1464  .WillOnce(Return(OkStatus()));
1465  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1466  succeeding_id, std::unique_ptr<Loader>(succeeding_loader))));
1467  basic_manager_->LoadServable(
1468  succeeding_id, [](const Status& status) { TF_EXPECT_OK(status); });
1469 }
1470 
1471 TEST_F(ResourceConstrainedBasicManagerTest, ResourcesReleasedAfterUnload) {
1472  const ServableId unloading_id = {"unloading", 0};
1473  test_util::MockLoader* unloading_loader = new NiceMock<test_util::MockLoader>;
1474  ON_CALL(*unloading_loader, EstimateResources(_))
1475  .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1476  *estimate = CreateResourceQuantity(10);
1477  return OkStatus();
1478  }));
1479  Notification load_done;
1480  EXPECT_CALL(*unloading_loader,
1481  LoadWithMetadata(Loader::Metadata{unloading_id}))
1482  .WillOnce(Return(OkStatus()));
1483  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1484  unloading_id, std::unique_ptr<Loader>(unloading_loader))));
1485  basic_manager_->LoadServable(unloading_id,
1486  [&load_done](const Status& status) {
1487  TF_EXPECT_OK(status);
1488  load_done.Notify();
1489  });
1490  load_done.WaitForNotification();
1491  Notification unload_started;
1492  Notification finish_unload;
1493  EXPECT_CALL(*unloading_loader, Unload())
1494  .WillOnce(Invoke([&unload_started, &finish_unload] {
1495  unload_started.Notify();
1496  finish_unload.WaitForNotification();
1497  }));
1498  basic_manager_->UnloadServable(
1499  unloading_id, [](const Status& status) { TF_EXPECT_OK(status); });
1500  unload_started.WaitForNotification();
1501 
1502  // A second loader that succeeds. The unloading of the first loader should
1503  // enable this one to get loaded (versus rejection with a resource exhaustion
1504  // error).
1505  const ServableId succeeding_id = {"succeeding", 0};
1506  test_util::MockLoader* succeeding_loader =
1507  new NiceMock<test_util::MockLoader>;
1508  EXPECT_CALL(*succeeding_loader, EstimateResources(_))
1509  .WillOnce(Invoke([&finish_unload](ResourceAllocation* estimate) {
1510  finish_unload.Notify();
1511  *estimate = CreateResourceQuantity(10);
1512  return OkStatus();
1513  }))
1514  .WillOnce(Invoke([](ResourceAllocation* estimate) {
1515  *estimate = CreateResourceQuantity(10);
1516  return OkStatus();
1517  }));
1518  EXPECT_CALL(*succeeding_loader,
1519  LoadWithMetadata(Loader::Metadata{succeeding_id}))
1520  .WillOnce(Return(OkStatus()));
1521  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1522  succeeding_id, std::unique_ptr<Loader>(succeeding_loader))));
1523  basic_manager_->LoadServable(
1524  succeeding_id, [](const Status& status) { TF_EXPECT_OK(status); });
1525 
1526  // Force the manager to finish before deleting the notifications.
1527  basic_manager_.reset();
1528 }
1529 
1530 TEST_F(ResourceConstrainedBasicManagerTest, FirstLoadDeniedSecondOneApproved) {
1531  // A first loader that gets rejected due to insufficient resources.
1532  const ServableId denied_id = {"denied", 0};
1533  test_util::MockLoader* denied_loader = new NiceMock<test_util::MockLoader>;
1534  Notification denied_estimate_started;
1535  Notification finish_denied_estimate;
1536  EXPECT_CALL(*denied_loader, EstimateResources(_))
1537  .WillOnce(Invoke([&denied_estimate_started,
1538  &finish_denied_estimate](ResourceAllocation* estimate) {
1539  denied_estimate_started.Notify();
1540  finish_denied_estimate.WaitForNotification();
1541  *estimate = CreateResourceQuantity(11 /* more than the system total */);
1542  return OkStatus();
1543  }));
1544  // Load won't be called because resources are not enough to load it.
1545  EXPECT_CALL(*denied_loader, LoadWithMetadata(Loader::Metadata{denied_id}))
1546  .Times(0);
1547  TF_ASSERT_OK(basic_manager_->ManageServable(
1548  CreateServableData(denied_id, std::unique_ptr<Loader>(denied_loader))));
1549 
1550  // A second loader that succeeds.
1551  const ServableId succeeding_id = {"succeeding", 0};
1552  test_util::MockLoader* succeeding_loader =
1553  new NiceMock<test_util::MockLoader>;
1554  ON_CALL(*succeeding_loader, EstimateResources(_))
1555  .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1556  *estimate = CreateResourceQuantity(10);
1557  return OkStatus();
1558  }));
1559  TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1560  succeeding_id, std::unique_ptr<Loader>(succeeding_loader))));
1561 
1562  Status denied_load_status;
1563  // Place the first servable into a load request decision phase.
1564  basic_manager_->LoadServable(
1565  denied_id, [&denied_load_status](const Status& status) {
1566  denied_load_status = status;
1567  ASSERT_FALSE(status.ok());
1568  EXPECT_EQ(error::RESOURCE_EXHAUSTED, status.code());
1569  });
1570  denied_estimate_started.WaitForNotification();
1571  // The second servable's Load() call shouldn't occur until after the first
1572  // servable's load request exits its decision phase.
1573  EXPECT_CALL(*succeeding_loader,
1574  LoadWithMetadata(Loader::Metadata{succeeding_id}))
1575  .WillOnce(InvokeWithoutArgs([&finish_denied_estimate]() {
1576  // Ensure that the first servable's load request has been given
1577  // permission to exit its decision phase.
1578  EXPECT_TRUE(finish_denied_estimate.HasBeenNotified());
1579  return OkStatus();
1580  }));
1581 
1582  // Scoping ensures that the thread is run by the end of this scope.
1583  {
1584  // Have to run this in a thread otherwise we enter a deadlock because
1585  // LoadServable() locks a mutex which is already locked by the denied
1586  // servable's decision phase, and is waiting for finish_denied_estimate to
1587  // be notified.
1588  std::unique_ptr<Thread> load_servable(
1589  Env::Default()->StartThread({}, "LoadServable", [&]() {
1590  basic_manager_->LoadServable(succeeding_id, [](const Status& status) {
1591  TF_EXPECT_OK(status);
1592  });
1593  }));
1594 
1595  finish_denied_estimate.Notify();
1596  }
1597 
1598  // Force the manager to finish before deleting the notifications.
1599  basic_manager_.reset();
1600 
1601  const ServableState expected_error_state = {
1602  denied_id, ServableState::ManagerState::kEnd, denied_load_status};
1603  EXPECT_THAT(*servable_state_monitor_.GetState(denied_id),
1604  EqualsServableState(expected_error_state));
1605 }
1606 
1607 TEST_F(ResourceConstrainedBasicManagerTest, EventBusErrorOnEstimateResources) {
1608  const ServableId id = {kServableName, 7};
1609  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
1610  EXPECT_CALL(*loader, EstimateResources(_))
1611  .WillOnce(Return(errors::Internal("Error on estimate resources.")));
1612  TF_ASSERT_OK(basic_manager_->ManageServable(
1613  CreateServableData(id, std::unique_ptr<Loader>(loader))));
1614  basic_manager_->LoadServable(
1615  id, [](const Status& status) { EXPECT_FALSE(status.ok()); });
1616  WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
1617  {ServableState::ManagerState::kEnd});
1618  const ServableState error_state = {
1619  id, ServableState::ManagerState::kEnd,
1620  errors::Internal(strings::StrCat(
1621  "Error while attempting to reserve resources to load servable ",
1622  id.DebugString(), ": Error on estimate resources."))};
1623  EXPECT_THAT(*servable_state_monitor_.GetState(id),
1624  EqualsServableState(error_state));
1625 }
1626 
1627 TEST(EstimateResourcesRetriedTest, Succeeds) {
1628  std::shared_ptr<EventBus<ServableState>> servable_event_bus =
1630  ServableStateMonitor servable_state_monitor(servable_event_bus.get());
1631 
1632  BasicManager::Options options;
1633  // Seed the manager with ten resource units.
1634  options.resource_tracker = CreateSimpleResourceTracker(10);
1635  options.servable_event_bus = servable_event_bus.get();
1636  options.num_load_threads = 0;
1637  options.num_unload_threads = 0;
1638 
1639  options.max_num_load_retries = 1;
1640  options.load_retry_interval_micros = 0;
1641 
1642  std::unique_ptr<BasicManager> basic_manager;
1643  TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager));
1644 
1645  const ServableId id = {kServableName, 7};
1646  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
1647  EXPECT_CALL(*loader, EstimateResources(_))
1648  .WillOnce(Return(errors::Internal("Error on estimate resources.")))
1649  .WillOnce(Return(OkStatus()));
1650  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{id}))
1651  .WillRepeatedly(Return(OkStatus()));
1652  TF_ASSERT_OK(basic_manager->ManageServable(
1653  CreateServableData(id, std::unique_ptr<Loader>(loader))));
1654  basic_manager->LoadServable(
1655  id, [](const Status& status) { EXPECT_TRUE(status.ok()); });
1656  WaitUntilServableManagerStateIsOneOf(
1657  servable_state_monitor, id, {ServableState::ManagerState::kAvailable});
1658  const ServableState available_state = {
1659  id, ServableState::ManagerState::kAvailable, OkStatus()};
1660  EXPECT_THAT(*servable_state_monitor.GetState(id),
1661  EqualsServableState(available_state));
1662 }
1663 
1664 TEST(EstimateResourcesRetriedTest, Fails) {
1665  std::shared_ptr<EventBus<ServableState>> servable_event_bus =
1667  ServableStateMonitor servable_state_monitor(servable_event_bus.get());
1668 
1669  BasicManager::Options options;
1670  // Seed the manager with ten resource units.
1671  options.resource_tracker = CreateSimpleResourceTracker(10);
1672  options.servable_event_bus = servable_event_bus.get();
1673  options.num_load_threads = 0;
1674  options.num_unload_threads = 0;
1675 
1676  options.max_num_load_retries = 1;
1677  options.load_retry_interval_micros = 0;
1678 
1679  std::unique_ptr<BasicManager> basic_manager;
1680  TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager));
1681 
1682  const ServableId id = {kServableName, 7};
1683  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
1684  EXPECT_CALL(*loader, EstimateResources(_))
1685  .WillOnce(Return(errors::Internal("Error on estimate resources.")))
1686  .WillOnce(Return(errors::Internal("Error on estimate resources.")))
1687  .WillRepeatedly(Return(OkStatus()));
1688  TF_ASSERT_OK(basic_manager->ManageServable(
1689  CreateServableData(id, std::unique_ptr<Loader>(loader))));
1690  basic_manager->LoadServable(
1691  id, [](const Status& status) { EXPECT_FALSE(status.ok()); });
1692  WaitUntilServableManagerStateIsOneOf(servable_state_monitor, id,
1693  {ServableState::ManagerState::kEnd});
1694  EXPECT_FALSE(servable_state_monitor.GetState(id)->health.ok());
1695 }
1696 
1697 TEST(EstimateResourcesRetriedTest, NonRetriableError) {
1698  std::shared_ptr<EventBus<ServableState>> servable_event_bus =
1700  ServableStateMonitor servable_state_monitor(servable_event_bus.get());
1701 
1702  BasicManager::Options options;
1703  // Seed the manager with ten resource units.
1704  options.resource_tracker = CreateSimpleResourceTracker(10);
1705  options.servable_event_bus = servable_event_bus.get();
1706  options.num_load_threads = 0;
1707  options.num_unload_threads = 0;
1708  options.should_retry_model_load =
1709  ([](absl::Status status) { return !absl::IsInvalidArgument(status); });
1710 
1711  options.max_num_load_retries = 10;
1712  options.load_retry_interval_micros = 100000000;
1713 
1714  std::unique_ptr<BasicManager> basic_manager;
1715  TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager));
1716 
1717  const ServableId id = {kServableName, 7};
1718  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
1719  EXPECT_CALL(*loader, LoadWithMetadata(_))
1720  .WillOnce(Return(errors::InvalidArgument("Non-retriable error.")))
1721  .WillRepeatedly(Return(absl::OkStatus()));
1722  TF_ASSERT_OK(basic_manager->ManageServable(
1723  CreateServableData(id, std::unique_ptr<Loader>(loader))));
1724  basic_manager->LoadServable(
1725  id, [](const auto& status) { EXPECT_FALSE(status.ok()); });
1726 
1727  // Make sure the final state is kEnd.
1728  WaitUntilServableManagerStateIsOneOf(
1729  servable_state_monitor, id,
1730  {ServableState::ManagerState::kEnd,
1731  ServableState::ManagerState::kAvailable});
1732  const auto final_state = servable_state_monitor.GetState(id);
1733  ASSERT_TRUE(final_state.has_value());
1734  EXPECT_EQ(final_state->manager_state, ServableState::ManagerState::kEnd);
1735  EXPECT_FALSE(final_state->health.ok());
1736  EXPECT_EQ(final_state->health.message(), "Non-retriable error.");
1737 }
1738 
1739 } // namespace
1740 } // namespace serving
1741 } // namespace tensorflow
static std::shared_ptr< EventBus > CreateEventBus(const Options &options={})
Definition: event_bus.h:191
@ kReady
'loader_->Load()' has succeeded.
@ kDisabled
'loader_->Unload()' has finished.
absl::optional< ServableState > GetState(const ServableId &servable_id) const TF_LOCKS_EXCLUDED(mu_)