16 #include "tensorflow_serving/core/basic_manager.h"
27 #include <gmock/gmock.h>
28 #include <gtest/gtest.h>
29 #include "absl/status/status.h"
30 #include "absl/types/optional.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/blocking_counter.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/null_file_system.h"
37 #include "tensorflow/core/protobuf/error_codes.pb.h"
38 #include "tensorflow_serving/core/servable_state_monitor.h"
39 #include "tensorflow_serving/core/test_util/availability_test_util.h"
40 #include "tensorflow_serving/core/test_util/fake_loader.h"
41 #include "tensorflow_serving/core/test_util/manager_test_util.h"
42 #include "tensorflow_serving/core/test_util/mock_loader.h"
43 #include "tensorflow_serving/util/any_ptr.h"
44 #include "tensorflow_serving/util/event_bus.h"
45 #include "tensorflow_serving/util/threadpool_executor.h"
47 namespace tensorflow {
51 using test_util::FakeLoader;
52 using test_util::WaitUntilServableManagerStateIsOneOf;
54 using ::testing::AnyOf;
55 using ::testing::HasSubstr;
56 using ::testing::InSequence;
57 using ::testing::Invoke;
58 using ::testing::InvokeWithoutArgs;
59 using ::testing::MockFunction;
60 using ::testing::NiceMock;
61 using ::testing::Return;
62 using ::testing::UnorderedElementsAre;
63 using ::testing::UnorderedElementsAreArray;
// Servable names used by the tests in this file. SetUp() of BasicManagerTest
// pre-loads kServableName and kServableName2; kServableName3 is used by
// individual tests.
65 constexpr
char kServableName[] =
"kServableName";
66 constexpr
char kServableName2[] =
"kServableName2";
67 constexpr
char kServableName3[] =
"kServableName3";
// Upper bound of the version loop in BasicManagerTest::SetUp(): versions
// 1..kNumVersionsPerServable are loaded for each pre-loaded servable name.
69 constexpr
int kNumVersionsPerServable = 2;
// Thread count passed to the "Unloads" ThreadPoolExecutor in the FastLoad
// test.
71 constexpr
int kNumThreads = 10;
// gMock matcher that compares the matched ServableState against
// `servable_state` using operator==. The matcher description is the expected
// state's DebugString().
73 MATCHER_P(EqualsServableState, servable_state, servable_state.DebugString()) {
74 if (arg == servable_state) {
// Streams the actual value's DebugString() to the result listener
// (presumably only on mismatch -- the branch structure is not fully visible).
77 *result_listener << arg.DebugString();
// Builds a ServableData for `id` that wraps a FakeLoader constructed from
// `id.version` and `load_status`. `load_status` defaults to OkStatus().
82 ServableData<std::unique_ptr<Loader>> CreateServable(
83 const ServableId&
id,
const Status load_status = OkStatus()) {
84 std::unique_ptr<Loader> loader(
new FakeLoader(
id.version, load_status));
85 return CreateServableData(
id, std::move(loader));
// Test parameter for BasicManagerTest: the values copied into
// BasicManager::Options::num_load_threads and ::num_unload_threads.
90 struct ThreadPoolSizes {
91 uint64_t num_load_threads;
92 uint64_t num_unload_threads;
// Fixture parameterized on ThreadPoolSizes. It owns an event bus, a
// ServableStateMonitor subscribed to that bus, and the BasicManager under
// test.
94 class BasicManagerTest :
public ::testing::TestWithParam<ThreadPoolSizes> {
// Constructor: reads the thread-pool sizes from the test parameter, creates
// the event bus and the monitor, then builds the manager with up to 10 load
// retries and a zero retry interval.
97 : thread_pool_sizes_(GetParam()),
98 servable_event_bus_(EventBus<ServableState>::CreateEventBus()),
99 servable_state_monitor_(servable_event_bus_.get()) {
100 BasicManager::Options options;
101 options.num_load_threads = thread_pool_sizes_.num_load_threads;
102 options.num_unload_threads = thread_pool_sizes_.num_unload_threads;
103 options.servable_event_bus = servable_event_bus_.get();
104 options.max_num_load_retries = 10;
105 options.load_retry_interval_micros = 0;
106 TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager_));
// Manages and loads versions 1..kNumVersionsPerServable of kServableName and
// kServableName2, then waits until each one is reported kAvailable.
109 void SetUp()
override {
112 std::set<ServableId> loaded_servables;
113 for (
const char* servable_name : {kServableName, kServableName2}) {
114 for (
int i = 1; i <= kNumVersionsPerServable; ++i) {
115 const ServableId
id = {servable_name, i};
116 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
117 basic_manager_->LoadServable(
118 id, [](
const Status& status) { TF_ASSERT_OK(status); });
119 loaded_servables.insert(
id);
// Block until every servable requested above has reached kAvailable.
122 for (
const ServableId& loaded_servable : loaded_servables) {
123 WaitUntilServableManagerStateIsOneOf(
124 servable_state_monitor_, loaded_servable,
125 {ServableState::ManagerState::kAvailable});
// Members are declared in the order the constructor initializes them: the
// monitor is constructed from the bus, so the bus is declared first.
129 ThreadPoolSizes thread_pool_sizes_;
130 std::shared_ptr<EventBus<ServableState>> servable_event_bus_;
131 ServableStateMonitor servable_state_monitor_;
132 std::unique_ptr<BasicManager> basic_manager_;
// Runs every BasicManagerTest with four {num_load_threads,
// num_unload_threads} combinations: {0,0}, {2,0}, {0,2} and {4,4}.
135 INSTANTIATE_TEST_CASE_P(
136 WithOrWithoutThreadPools, BasicManagerTest,
138 ThreadPoolSizes{0, 0} ,
139 ThreadPoolSizes{2, 0} ,
140 ThreadPoolSizes{0, 2} ,
141 ThreadPoolSizes{4, 4} ));
// Requesting the latest version of a servable name that was never managed
// (kServableName with "missing" appended) must fail with NOT_FOUND.
143 TEST_P(BasicManagerTest, ServableHandleNotFoundMissingLoaderName) {
144 ServableHandle<int64_t> handle;
145 const Status status = basic_manager_->GetServableHandle(
146 ServableRequest::Latest(strings::StrCat(kServableName,
"missing")),
148 ASSERT_FALSE(status.ok()) << status;
149 EXPECT_EQ(error::NOT_FOUND, status.code());
// Requesting version 100 of kServableName, which this test never manages,
// must fail with NOT_FOUND.
152 TEST_P(BasicManagerTest, ServableHandleNotFoundMissingVersion) {
154 const int64_t missing_version = 100;
155 ServableHandle<int64_t> handle;
156 const Status status = basic_manager_->GetServableHandle(
157 ServableRequest::Specific(kServableName, missing_version), &handle);
158 ASSERT_FALSE(status.ok()) << status;
159 EXPECT_EQ(error::NOT_FOUND, status.code());
// An Earliest request for kServableName must yield a handle whose value is 1.
162 TEST_P(BasicManagerTest, ServableHandleEarliest) {
// Guard: the check is only meaningful when more than one version is loaded.
163 ASSERT_GT(kNumVersionsPerServable, 1);
164 ServableHandle<int64_t> handle;
165 const Status status = basic_manager_->GetServableHandle(
166 ServableRequest::Earliest(kServableName), &handle);
167 TF_ASSERT_OK(status);
168 EXPECT_EQ(1, *handle);
// Loads one extra version (kNumVersionsPerServable + 1) of kServableName and
// checks that a Latest request returns a handle with that value.
171 TEST_P(BasicManagerTest, ServableHandleLatest) {
172 const ServableId
id = {kServableName, kNumVersionsPerServable + 1};
173 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
174 basic_manager_->LoadServable(
175 id, [](
const Status& status) { TF_ASSERT_OK(status); });
176 WaitUntilServableManagerStateIsOneOf(
177 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
179 ServableHandle<int64_t> handle;
180 const Status status = basic_manager_->GetServableHandle(
181 ServableRequest::Latest(kServableName), &handle);
182 TF_ASSERT_OK(status);
183 EXPECT_EQ(kNumVersionsPerServable + 1, *handle);
// Calling ManageServable() a second time with the same id must return a
// non-OK status.
186 TEST_P(BasicManagerTest, AlreadyManagedError) {
187 const ServableId
id = {
"banana", 42};
188 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
189 EXPECT_FALSE(basic_manager_->ManageServable(CreateServable(
id)).ok());
// Loads a single version of kServableName3 and checks that a Latest request
// returns it, verifying both the handle value and the handle id.
// NOTE(review): despite the test name, the id used here is version 1, not 0 --
// confirm whether the name or the version is stale.
193 TEST_P(BasicManagerTest, ServableHandleLatestVersionIsZero) {
194 const ServableId
id = {kServableName3, 1};
195 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
196 basic_manager_->LoadServable(
197 id, [](
const Status& status) { TF_ASSERT_OK(status); });
198 WaitUntilServableManagerStateIsOneOf(
199 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
201 ServableHandle<int64_t> handle;
202 const Status status = basic_manager_->GetServableHandle(
203 ServableRequest::Latest(kServableName3), &handle);
204 TF_ASSERT_OK(status);
205 EXPECT_EQ(1, *handle);
206 EXPECT_EQ(
id, handle.id());
// StopManagingServable() on an id that was never managed must return a
// non-OK status.
209 TEST_P(BasicManagerTest, StopManagingUnknownId) {
210 const ServableId
id = {kServableName3, 1};
211 EXPECT_FALSE(basic_manager_->StopManagingServable(
id).ok());
// StopManagingServable() must fail while the servable is still loaded
// (state kAvailable).
214 TEST_P(BasicManagerTest, StopManagingActiveServable) {
215 const ServableId
id = {kServableName3, 1};
216 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
217 basic_manager_->LoadServable(
218 id, [](
const Status& status) { TF_EXPECT_OK(status); });
219 WaitUntilServableManagerStateIsOneOf(
220 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
221 EXPECT_FALSE(basic_manager_->StopManagingServable(
id).ok());
// After a load followed by an unload (state kEnd), StopManagingServable()
// must succeed, and the manager must no longer return a state snapshot for
// the id.
224 TEST_P(BasicManagerTest, StopManagingDisabledServable) {
225 const ServableId
id = {kServableName3, 1};
226 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
227 basic_manager_->LoadServable(
228 id, [](
const Status& status) { TF_EXPECT_OK(status); });
229 WaitUntilServableManagerStateIsOneOf(
230 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
231 basic_manager_->UnloadServable(
232 id, [](
const Status& status) { TF_EXPECT_OK(status); });
233 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
234 {ServableState::ManagerState::kEnd});
// Snapshot taken while the unloaded servable is still managed.
235 const absl::optional<ServableStateSnapshot<>> snapshot =
236 basic_manager_->GetManagedServableStateSnapshot(
id);
238 const ServableState expected_state = {id, ServableState::ManagerState::kEnd,
240 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
241 EqualsServableState(expected_state));
// Once management stops, the snapshot lookup must come back empty.
243 TF_ASSERT_OK(basic_manager_->StopManagingServable(
id));
244 EXPECT_FALSE(basic_manager_->GetManagedServableStateSnapshot(
id));
// A servable whose FakeLoader is constructed with an Internal error must
// report that error to the load callback and end in state kEnd with the
// same error. A state snapshot is still requested from the manager afterwards.
247 TEST_P(BasicManagerTest, DontStopManagingOnError) {
248 const ServableId
id = {kServableName, 7};
249 const Status error_status = errors::Internal(
"An error.");
250 std::unique_ptr<Loader> loader(
new FakeLoader(7, error_status));
251 TF_ASSERT_OK(basic_manager_->ManageServable({id, std::move(loader)}));
252 basic_manager_->LoadServable(
id, [error_status](
const Status& status) {
253 EXPECT_EQ(error_status, status);
255 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
256 {ServableState::ManagerState::kEnd});
257 const absl::optional<ServableStateSnapshot<>> snapshot =
258 basic_manager_->GetManagedServableStateSnapshot(
id);
// The state published on the event bus must carry the load error.
260 const ServableState expected_error_state = {
261 id, ServableState::ManagerState::kEnd, error_status};
262 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
263 EqualsServableState(expected_error_state));
// A request built with ServableRequest::FromId for {kServableName2, 1} must
// return a handle with value 1 and the same id.
266 TEST_P(BasicManagerTest, ServableHandleSpecificVersion) {
267 ServableHandle<int64_t> handle;
268 const ServableId
id = {kServableName2, 1};
269 const Status status =
270 basic_manager_->GetServableHandle(ServableRequest::FromId(
id), &handle);
271 TF_ASSERT_OK(status);
272 EXPECT_EQ(1, *handle);
273 EXPECT_EQ(
id, handle.id());
// Loads versions 0 and 1 of kServableName3, then starts unloading version 1
// with a mock loader whose Unload() blocks. While that unload is blocked, a
// Latest request must already resolve to version 0.
278 TEST_P(BasicManagerTest, UpdateServingMapServableHandleLatest) {
281 const ServableId id0 = {kServableName3, 0};
283 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id0)));
284 basic_manager_->LoadServable(
285 id0, [](
const Status& status) { TF_ASSERT_OK(status); });
286 WaitUntilServableManagerStateIsOneOf(
287 servable_state_monitor_, id0, {ServableState::ManagerState::kAvailable});
// Raw pointer is kept so expectations can be set after ownership moves into
// the manager via the unique_ptr below.
289 test_util::MockLoader* notify_to_unload =
new NiceMock<test_util::MockLoader>;
292 int64_t servable = 1;
293 ON_CALL(*notify_to_unload, servable())
294 .WillByDefault(Return(AnyPtr(&servable)));
295 ON_CALL(*notify_to_unload, EstimateResources(_))
296 .WillByDefault(Return(OkStatus()));
// NOTE(review): this default action is keyed on Metadata{id0}, but the mock
// loader is managed under id1 below -- confirm this is intended.
297 ON_CALL(*notify_to_unload, LoadWithMetadata(Loader::Metadata{id0}))
298 .WillByDefault(Return(OkStatus()));
299 const ServableId id1 = {kServableName3, 1};
300 TF_ASSERT_OK(basic_manager_->ManageServable(
301 {id1, std::unique_ptr<Loader>(notify_to_unload)}));
302 basic_manager_->LoadServable(
303 id1, [](
const Status& status) { TF_ASSERT_OK(status); });
304 WaitUntilServableManagerStateIsOneOf(
305 servable_state_monitor_, id1, {ServableState::ManagerState::kAvailable});
// With both versions available, Latest must resolve to id1.
310 ServableHandle<int64_t> handle;
311 const Status status = basic_manager_->GetServableHandle(
312 ServableRequest::Latest(kServableName3), &handle);
313 TF_ASSERT_OK(status);
314 EXPECT_EQ(id1, handle.id());
// Unload() signals that it has started and then blocks until the test
// releases it through finish_unload.
320 Notification unload_started;
321 Notification finish_unload;
322 EXPECT_CALL(*notify_to_unload, Unload()).WillOnce(Invoke([&]() {
323 unload_started.Notify();
324 finish_unload.WaitForNotification();
// The unload is issued from a separate thread so this thread can query the
// manager while Unload() is blocked.
326 Notification unload_finished;
327 std::unique_ptr<Thread> unload_last_servable(
328 Env::Default()->StartThread({},
"UnloadLastServable", [&]() {
329 basic_manager_->UnloadServable(id1, [&](
const Status& status) {
330 TF_EXPECT_OK(status);
331 unload_finished.Notify();
334 unload_started.WaitForNotification();
// Unload of id1 is in progress: Latest must now resolve to id0.
338 ServableHandle<int64_t> handle;
339 const Status status = basic_manager_->GetServableHandle(
340 ServableRequest::Latest(kServableName3), &handle);
341 TF_EXPECT_OK(status);
342 EXPECT_EQ(id0, handle.id());
344 finish_unload.Notify();
// Wait for the unload callback before the locals captured by reference go
// out of scope.
348 unload_finished.WaitForNotification();
// Checks ListAvailableServableIds() before and after a failed load of
// {kServableName, 7} and the unloading of both kServableName versions.
// Afterwards only the two kServableName2 versions are expected.
351 TEST_P(BasicManagerTest, ListAvailableServableIds) {
352 const std::vector<ServableId> expected_before = {{kServableName, 1},
355 {kServableName2, 2}};
356 EXPECT_THAT(basic_manager_->ListAvailableServableIds(),
357 UnorderedElementsAreArray(expected_before));
// Version 7 uses a FakeLoader constructed with an Internal error; the load
// callback expects that same error.
361 const ServableId
id = {kServableName, 7};
362 std::unique_ptr<Loader> loader(
363 new FakeLoader(7, errors::Internal(
"An error.")));
364 TF_ASSERT_OK(basic_manager_->ManageServable(
365 CreateServableData(
id, std::move(loader))));
366 basic_manager_->LoadServable(
id, [](
const Status& status) {
367 EXPECT_EQ(errors::Internal(
"An error."), status);
369 basic_manager_->UnloadServable(
370 {kServableName, 1}, [](
const Status& status) { TF_ASSERT_OK(status); });
371 basic_manager_->UnloadServable(
372 {kServableName, 2}, [](
const Status& status) { TF_ASSERT_OK(status); });
// Wait for the servables to reach kEnd before listing again.
373 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
374 {ServableState::ManagerState::kEnd});
375 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
377 {ServableState::ManagerState::kEnd});
378 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
380 {ServableState::ManagerState::kEnd});
382 const std::vector<ServableId> expected_after = {{kServableName2, 1},
383 {kServableName2, 2}};
384 EXPECT_THAT(basic_manager_->ListAvailableServableIds(),
385 UnorderedElementsAreArray(expected_after));
// Checks GetAvailableServableHandles<int64_t>() before and after a failed
// load of {kServableName, 7} and the unloading of both kServableName
// versions. Each returned handle's value must equal its id's version.
388 TEST_P(BasicManagerTest, GetAvailableServableHandles) {
// Two servable names are pre-loaded by SetUp(), hence the factor of 2.
391 const std::map<ServableId, ServableHandle<int64_t>> handles_before =
392 basic_manager_->GetAvailableServableHandles<int64_t>();
393 ASSERT_EQ(kNumVersionsPerServable * 2, handles_before.size());
395 const std::vector<ServableId> expected_ids_before = {{kServableName, 1},
398 {kServableName2, 2}};
399 for (
const ServableId& expected_id : expected_ids_before) {
400 const auto found_it = handles_before.find(expected_id);
401 ASSERT_TRUE(found_it != handles_before.end());
402 EXPECT_EQ(expected_id.version, *found_it->second);
// Version 7 uses a FakeLoader constructed with an Internal error; the load
// callback expects that same error.
408 const ServableId
id = {kServableName, 7};
409 std::unique_ptr<Loader> loader(
410 new FakeLoader(7, errors::Internal(
"An error.")));
411 TF_ASSERT_OK(basic_manager_->ManageServable(
412 CreateServableData(
id, std::move(loader))));
413 basic_manager_->LoadServable(
id, [](
const Status& status) {
414 EXPECT_EQ(errors::Internal(
"An error."), status);
416 basic_manager_->UnloadServable(
417 {kServableName, 1}, [](
const Status& status) { TF_ASSERT_OK(status); });
418 basic_manager_->UnloadServable(
419 {kServableName, 2}, [](
const Status& status) { TF_ASSERT_OK(status); });
// Wait for the servables to reach kEnd before fetching handles again.
420 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
421 {ServableState::ManagerState::kEnd});
422 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
424 {ServableState::ManagerState::kEnd});
425 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
427 {ServableState::ManagerState::kEnd});
// Only the kServableName2 versions should remain available.
430 const std::map<ServableId, ServableHandle<int64_t>> handles_after =
431 basic_manager_->GetAvailableServableHandles<int64_t>();
432 ASSERT_EQ(kNumVersionsPerServable, handles_after.size());
434 const std::vector<ServableId> expected_ids_after = {{kServableName2, 1},
435 {kServableName2, 2}};
436 for (
const ServableId& expected_id : expected_ids_after) {
437 const auto found_it = handles_after.find(expected_id);
438 ASSERT_TRUE(found_it != handles_after.end());
439 EXPECT_EQ(expected_id.version, *found_it->second);
// Requesting handles with type `int` must return an empty map; the other
// tests in this file read the same servables as int64_t.
444 TEST_P(BasicManagerTest, GetAvailableServableHandlesWrongType) {
445 const std::map<ServableId, ServableHandle<int>> wrong_type_handles =
446 basic_manager_->GetAvailableServableHandles<
int>();
447 EXPECT_EQ(0, wrong_type_handles.size());
// GetManagedServableNames() must return exactly the two names managed by
// SetUp(), in any order.
450 TEST_P(BasicManagerTest, GetManagedServableNames) {
451 EXPECT_THAT(basic_manager_->GetManagedServableNames(),
452 UnorderedElementsAre(kServableName, kServableName2));
// Compares GetManagedServableStateSnapshots(kServableName) against an
// expected list of snapshots, ignoring order. The initializer of `expected`
// is not visible in this view.
455 TEST_P(BasicManagerTest,
456 GetManagedServableStateSnapshotWithoutAdditionalState) {
457 const std::vector<ServableStateSnapshot<>> expected = {
460 EXPECT_THAT(basic_manager_->GetManagedServableStateSnapshots(kServableName),
461 UnorderedElementsAreArray(expected));
// A managed id ({kServableName, 1}) must yield a snapshot equal to the
// expected one; an unmanaged id ({kServableName, 8}) must yield an empty
// optional. The initializer of the expected snapshot is not visible here.
464 TEST_P(BasicManagerTest, GetManagedServableStateSnapshot) {
467 const ServableId id_ready = {kServableName, 1};
468 const absl::optional<ServableStateSnapshot<>> actual_ready_snapshot =
469 basic_manager_->GetManagedServableStateSnapshot(id_ready);
470 EXPECT_TRUE(actual_ready_snapshot);
471 const ServableStateSnapshot<> expected_ready_snapshot = {
473 EXPECT_EQ(actual_ready_snapshot, expected_ready_snapshot);
// Version 8 is never managed by SetUp() or by this test.
477 const ServableId id_notmanaged = {kServableName, 8};
478 EXPECT_FALSE(basic_manager_->GetManagedServableStateSnapshot(id_notmanaged));
// Manages versions 0 and 1 of kServableName3, attaching an int (0 and 1
// respectively) as additional state, then compares the typed snapshots for
// kServableName3 against `expected`, ignoring order. The initializer of
// `expected` is not visible in this view.
481 TEST_P(BasicManagerTest, GetManagedServableStateSnapshotsWithAdditionalState) {
482 TF_CHECK_OK(basic_manager_->ManageServableWithAdditionalState(
483 CreateServable({kServableName3, 0}), std::unique_ptr<int>(
new int(0))));
484 TF_CHECK_OK(basic_manager_->ManageServableWithAdditionalState(
485 CreateServable({kServableName3, 1}), std::unique_ptr<int>(
new int(1))));
486 const std::vector<ServableStateSnapshot<int>> expected = {
490 basic_manager_->GetManagedServableStateSnapshots<
int>(kServableName3),
491 UnorderedElementsAreArray(expected));
// {kServableName, 1} is already managed and loaded by SetUp(). Two further
// attempts with error-returning FakeLoaders are expected to fail
// (EXPECT_FALSE), and the original servable must still be served with
// value 1.
// NOTE(review): the manager call inside each EXPECT_FALSE is on lines not
// visible here -- presumably ManageServable().
494 TEST_P(BasicManagerTest, MultipleManageCallsUsesFirstServable) {
495 const ServableId
id = {kServableName, 1};
499 std::unique_ptr<Loader> first_ignored_loader(
500 new FakeLoader(1, errors::Internal(
"An error.")));
501 EXPECT_FALSE(basic_manager_
503 CreateServableData(
id, std::move(first_ignored_loader)))
507 std::unique_ptr<Loader> second_ignored_loader(
508 new FakeLoader(2, errors::Internal(
"An error.")));
509 EXPECT_FALSE(basic_manager_
511 CreateServableData(
id, std::move(second_ignored_loader)))
// The servable loaded by SetUp() must be unaffected.
514 ServableHandle<int64_t> handle;
515 TF_ASSERT_OK(basic_manager_->GetServableHandle(
516 ServableRequest::Specific(kServableName, 1), &handle));
517 EXPECT_EQ(1, *handle);
// A servable managed with an error status instead of a loader can be
// managed, but it cannot be fetched, its load reports a non-OK status, and
// it still cannot be fetched after the load attempt.
522 TEST_P(BasicManagerTest, ErroneousServable) {
523 const ServableId
id = {kServableName, 3};
524 TF_ASSERT_OK(basic_manager_->ManageServable(
525 ServableData<std::unique_ptr<Loader>>(
id, errors::Unknown(
"error"))));
527 ServableHandle<int64_t> handle;
528 Status status = basic_manager_->GetServableHandle(
529 ServableRequest::Specific(kServableName, 3), &handle);
530 EXPECT_FALSE(status.ok()) << status;
531 basic_manager_->LoadServable(
532 id, [](
const Status& status) { EXPECT_FALSE(status.ok()) << status; });
// Same lookup again, after the load attempt.
534 status = basic_manager_->GetServableHandle(
535 ServableRequest::Specific(kServableName, 3), &handle);
536 EXPECT_FALSE(status.ok()) << status;
// Holds a handle to version 7 on the test thread while a second thread
// unloads the servable and stops managing it. At the end the test asserts
// that the FakeLoader was not deleted on the test thread.
541 TEST_P(BasicManagerTest, DestructOnNonServingThread) {
542 const ServableId
id = {kServableName, 7};
543 TF_ASSERT_OK(basic_manager_->ManageServable(
544 CreateServableData(
id, std::unique_ptr<Loader>(
new FakeLoader(7)))));
545 basic_manager_->LoadServable(
546 id, [](
const Status& status) { TF_ASSERT_OK(status); });
547 WaitUntilServableManagerStateIsOneOf(
548 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
// The handle is heap-allocated so it can be released explicitly with
// reset() further down.
550 std::unique_ptr<ServableHandle<int64_t>> latest_handle(
551 new ServableHandle<int64_t>());
552 const Status status = basic_manager_->GetServableHandle(
553 ServableRequest::Latest(kServableName), latest_handle.get());
554 TF_ASSERT_OK(status);
555 EXPECT_EQ(7, **latest_handle);
557 Notification done_unload_servable;
558 std::unique_ptr<Thread> unload_servable(
559 Env::Default()->StartThread({},
"UnloadServable", [&]() {
561 basic_manager_->UnloadServable(
562 id, [](
const Status& status) { TF_ASSERT_OK(status); });
563 WaitUntilServableManagerStateIsOneOf(
564 servable_state_monitor_,
id, {ServableState::ManagerState::kEnd});
566 TF_ASSERT_OK(basic_manager_->StopManagingServable(
id));
// Checked only when the fixture is configured with zero load threads.
569 if (thread_pool_sizes_.num_load_threads == 0) {
570 EXPECT_TRUE(FakeLoader::was_deleted_in_this_thread());
572 done_unload_servable.Notify();
// Release the handle on this thread, then wait for the unload thread.
576 latest_handle.reset();
577 done_unload_servable.WaitForNotification();
579 ASSERT_FALSE(FakeLoader::was_deleted_in_this_thread());
// Additional state attached as an int can be read back as int (value 1).
// Asking for it as float returns nullptr.
582 TEST_P(BasicManagerTest, AdditionalState) {
583 const ServableId
id = {kServableName, 3};
584 std::unique_ptr<int> state(
new int(1));
585 TF_CHECK_OK(basic_manager_->ManageServableWithAdditionalState(
586 CreateServable(
id), std::move(state)));
588 EXPECT_EQ(1, *basic_manager_->GetAdditionalServableState<
int>(
id));
589 EXPECT_EQ(
nullptr, basic_manager_->GetAdditionalServableState<
float>(
id));
// A servable managed through plain ManageServable() has no additional state:
// lookups return nullptr for both int and float.
592 TEST_P(BasicManagerTest, NoAdditionalState) {
593 const ServableId
id = {kServableName, 3};
594 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
597 EXPECT_EQ(
nullptr, basic_manager_->GetAdditionalServableState<
int>(
id));
598 EXPECT_EQ(
nullptr, basic_manager_->GetAdditionalServableState<
float>(
id));
// LoadServable() on an id that was never managed must report NOT_FOUND with
// a message containing "is not being managed".
601 TEST_P(BasicManagerTest, OutOfOrderLoadServable) {
602 const ServableId
id = {kServableName, 3};
603 basic_manager_->LoadServable(
id, [](
const Status& status) {
604 EXPECT_FALSE(status.ok());
605 EXPECT_EQ(error::NOT_FOUND, status.code());
606 EXPECT_THAT(status.message(), HasSubstr(
"is not being managed"));
// A second LoadServable() for an already-available servable must report
// FAILED_PRECONDITION with a message containing "Duplicate load request".
610 TEST_P(BasicManagerTest, MultipleLoadServables) {
611 const ServableId
id = {kServableName, 3};
612 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
613 basic_manager_->LoadServable(
614 id, [](
const Status& status) { TF_ASSERT_OK(status); });
615 WaitUntilServableManagerStateIsOneOf(
616 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
617 basic_manager_->LoadServable(
id, [](
const Status& status) {
618 EXPECT_FALSE(status.ok());
619 EXPECT_EQ(error::FAILED_PRECONDITION, status.code());
620 EXPECT_THAT(status.message(), HasSubstr(
"Duplicate load request"));
// A second UnloadServable() after the servable reached kEnd must report
// FAILED_PRECONDITION with a message containing
// "unload already requested/ongoing".
624 TEST_P(BasicManagerTest, MultipleUnloadServables) {
625 const ServableId
id = {kServableName, 3};
626 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
627 basic_manager_->LoadServable(
628 id, [](
const Status& status) { TF_ASSERT_OK(status); });
629 WaitUntilServableManagerStateIsOneOf(
630 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
631 basic_manager_->UnloadServable(
632 id, [](
const Status& status) { TF_ASSERT_OK(status); });
633 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
634 {ServableState::ManagerState::kEnd});
635 basic_manager_->UnloadServable(
id, [](
const Status& status) {
636 EXPECT_FALSE(status.ok());
637 EXPECT_EQ(error::FAILED_PRECONDITION, status.code());
638 EXPECT_THAT(status.message(),
639 HasSubstr(
"unload already requested/ongoing"));
// UnloadServable() on an id that was never managed must report NOT_FOUND
// with a message containing "is not being managed".
643 TEST_P(BasicManagerTest, UnloadWithoutManage) {
644 const ServableId
id = {kServableName, 3};
645 basic_manager_->UnloadServable(
id, [](
const Status& status) {
646 EXPECT_FALSE(status.ok());
647 EXPECT_EQ(error::NOT_FOUND, status.code());
648 EXPECT_THAT(status.message(), HasSubstr(
"is not being managed"));
// UnloadServable() on a managed but never-loaded servable must report
// FAILED_PRECONDITION with a message containing "Servable not loaded".
652 TEST_P(BasicManagerTest, UnloadWithoutLoad) {
653 const ServableId
id = {kServableName, 3};
654 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
655 basic_manager_->UnloadServable(
id, [](
const Status& status) {
656 EXPECT_FALSE(status.ok());
657 EXPECT_EQ(error::FAILED_PRECONDITION, status.code());
658 EXPECT_THAT(status.message(), HasSubstr(
"Servable not loaded"));
// Managing a servable whose ServableData carries an error must publish state
// kEnd with that same error. The state is read from the monitor right after
// ManageServable() returns, without any wait.
662 TEST_P(BasicManagerTest, EventBusErroneousVersion) {
663 const ServableId
id = {kServableName, 3};
664 TF_ASSERT_OK(basic_manager_->ManageServable(
665 ServableData<std::unique_ptr<Loader>>(
id, errors::Unknown(
"error"))));
667 const ServableState expected_published_state = {
668 id, ServableState::ManagerState::kEnd, errors::Unknown(
"error")};
669 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
670 EqualsServableState(expected_published_state));
// With a FakeLoader that fails to load, the published state must be kStart
// after ManageServable(), and kEnd carrying the Internal load error after
// the load attempt.
673 TEST_P(BasicManagerTest, EventBusErrorOnLoad) {
674 const ServableId
id = {kServableName, 7};
675 std::unique_ptr<Loader> loader(
676 new FakeLoader(7, errors::Internal(
"Error on load.")));
677 TF_ASSERT_OK(basic_manager_->ManageServable({id, std::move(loader)}));
679 const ServableState start_state = {id, ServableState::ManagerState::kStart,
681 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
682 EqualsServableState(start_state));
// The callback is deliberately empty; the outcome is verified through the
// published state instead.
684 basic_manager_->LoadServable(
id, [](
const Status& status) {});
685 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
686 {ServableState::ManagerState::kEnd});
688 const ServableState error_state = {id, ServableState::ManagerState::kEnd,
689 errors::Internal(
"Error on load.")};
690 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
691 EqualsServableState(error_state));
// Walks one servable through kStart -> kLoading -> kAvailable -> kUnloading
// -> kEnd and checks the published state at each step. The mock loader's
// LoadWithMetadata() and Unload() block on notifications so the intermediate
// states can be observed.
694 TEST_P(BasicManagerTest, EventBusServableLifecycle) {
695 const ServableId
id = {kServableName, 7};
// Raw pointer is kept so expectations can be set after ownership moves into
// the manager.
696 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>();
698 basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
700 const ServableState start_state = {id, ServableState::ManagerState::kStart,
702 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
703 EqualsServableState(start_state));
// LoadWithMetadata() signals that it was called, then blocks until the test
// releases it.
705 Notification load_called;
706 Notification load_continue;
707 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{
id}))
708 .WillOnce(InvokeWithoutArgs([&]() {
709 load_called.Notify();
710 load_continue.WaitForNotification();
// The load runs on its own thread so this thread can inspect the state while
// LoadWithMetadata() is blocked.
714 std::unique_ptr<Thread> load_thread(
715 Env::Default()->StartThread(ThreadOptions(),
"LoadThread", [&]() {
716 basic_manager_->LoadServable(
id, [](
const Status& status) {});
719 load_called.WaitForNotification();
721 const ServableState loading_state = {
722 id, ServableState::ManagerState::kLoading, OkStatus()};
723 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
724 EqualsServableState(loading_state));
726 load_continue.Notify();
727 WaitUntilServableManagerStateIsOneOf(
728 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
730 const ServableState available_state = {
731 id, ServableState::ManagerState::kAvailable, OkStatus()};
732 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
733 EqualsServableState(available_state));
// Same blocking pattern for Unload().
735 Notification unload_called;
736 Notification unload_continue;
737 EXPECT_CALL(*loader, Unload()).WillOnce(Invoke([&]() {
738 unload_called.Notify();
739 unload_continue.WaitForNotification();
742 std::unique_ptr<Thread> unload_thread(
743 Env::Default()->StartThread(ThreadOptions(),
"UnloadThread", [&]() {
744 basic_manager_->UnloadServable(
id, [](
const Status& status) {});
747 unload_called.WaitForNotification();
749 const ServableState unloading_state = {
750 id, ServableState::ManagerState::kUnloading, OkStatus()};
751 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
752 EqualsServableState(unloading_state));
754 unload_continue.Notify();
755 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
756 {ServableState::ManagerState::kEnd});
758 const ServableState end_state = {id, ServableState::ManagerState::kEnd,
760 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
761 EqualsServableState(end_state));
// A manager created with a null event bus must still manage, load and unload
// a servable successfully. This test builds its own manager instead of using
// the fixture's.
765 TEST_P(BasicManagerTest, NoEventBus) {
766 BasicManager::Options options;
768 options.num_load_threads = 0;
770 options.servable_event_bus =
nullptr;
771 std::unique_ptr<BasicManager> manager;
772 TF_ASSERT_OK(BasicManager::Create(std::move(options), &manager));
774 const ServableId
id = {kServableName, 7};
775 std::unique_ptr<Loader> loader(
new FakeLoader(7));
776 TF_ASSERT_OK(manager->ManageServable({id, std::move(loader)}));
777 manager->LoadServable(
id, [](
const Status& status) { TF_ASSERT_OK(status); });
778 manager->UnloadServable(
id,
779 [](
const Status& status) { TF_ASSERT_OK(status); });
// Schedules management and loading of versions 0..19 of kServableName3 on a
// ThreadPoolExecutor, waits for all of them to become available, then
// schedules their unloads in descending version order on a second executor.
// The executors' thread-count arguments are on lines not visible here.
782 TEST_P(BasicManagerTest, LoadsThenUnloads) {
783 std::set<ServableId> servables;
786 ThreadPoolExecutor load_executor(Env::Default(),
"LoadServables",
788 for (
int i = 0; i < 20; ++i) {
789 const ServableId
id = {kServableName3, i};
790 servables.insert(
id);
// `id` is captured by value because the task runs after this loop iteration
// has ended.
791 load_executor.Schedule([
this,
id]() {
792 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
793 basic_manager_->LoadServable(
794 id, [](
const Status& status) { TF_ASSERT_OK(status); });
// Every load must have completed before the unloads are scheduled.
800 for (
const ServableId& servable : servables) {
801 WaitUntilServableManagerStateIsOneOf(
802 servable_state_monitor_, servable,
803 {ServableState::ManagerState::kAvailable});
807 ThreadPoolExecutor unload_executor(Env::Default(),
"UnloadServables",
810 for (
int i = 19; i >= 0; --i) {
811 unload_executor.Schedule([
this, i]() {
812 const ServableId
id = {kServableName3, i};
813 basic_manager_->UnloadServable(
814 id, [](
const Status& status) { TF_ASSERT_OK(status); });
// Schedules 20 tasks on one executor. Each task manages and loads its own
// version of kServableName3, waits on a per-task notification, and then
// unloads that version.
// NOTE(review): load_done is presumably notified inside the load callback on
// a line not visible here -- confirm.
820 TEST_P(BasicManagerTest, InterleavedLoadsAndUnloads) {
821 ThreadPoolExecutor executor(Env::Default(),
"InterleavedLoadsAndUnloads",
823 for (
int i = 0; i < 20; ++i) {
824 executor.Schedule([
this, i]() {
825 const ServableId
id = {kServableName3, i};
826 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
827 Notification load_done;
828 basic_manager_->LoadServable(
id, [&load_done](
const Status& status) {
829 TF_ASSERT_OK(status);
832 load_done.WaitForNotification();
833 basic_manager_->UnloadServable(
834 id, [](
const Status& status) { TF_ASSERT_OK(status); });
// Fixture for tests that change the number of load threads at runtime. The
// manager starts with zero load threads, up to 10 load retries and a zero
// retry interval.
839 class SetNumLoadThreadsBasicManagerTest :
public ::testing::Test {
841 SetNumLoadThreadsBasicManagerTest() {
842 BasicManager::Options options;
843 options.num_load_threads = 0;
844 options.max_num_load_retries = 10;
845 options.load_retry_interval_micros = 0;
846 TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager_));
849 std::unique_ptr<BasicManager> basic_manager_;
// Loads one servable with 2 load threads and another after switching back to
// 0 load threads. A thread_local counter in the shared load callback asserts
// that no single thread runs the callback more than once.
852 TEST_F(SetNumLoadThreadsBasicManagerTest, ThreadPoolSwapped) {
853 test_util::BasicManagerTestAccess manager_test_access(basic_manager_.get());
854 manager_test_access.SetNumLoadThreads(2);
855 EXPECT_EQ(2, manager_test_access.num_load_threads());
857 const auto load_done_fn = [&](
const Status& status) {
858 TF_ASSERT_OK(status);
// The counter is per thread: a second invocation on the same thread would
// raise it to 2 and fail the expectation.
861 static thread_local
int per_thread_load_ctr = 0;
862 ++per_thread_load_ctr;
863 EXPECT_EQ(1, per_thread_load_ctr);
866 const ServableId id0 = {kServableName3, 0};
867 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id0)));
868 basic_manager_->LoadServable(id0, load_done_fn);
870 manager_test_access.SetNumLoadThreads(0);
871 EXPECT_EQ(0, manager_test_access.num_load_threads());
873 const ServableId id1 = {kServableName3, 1};
874 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id1)));
875 basic_manager_->LoadServable(id1, load_done_fn);
// The manager is destroyed explicitly before the test body's locals go out
// of scope.
878 basic_manager_.reset();
// Starts a load whose callback blocks, calls SetNumLoadThreads(1) again from
// another thread while that callback is blocked, and issues a second load.
// Both load callbacks insert into the same unsynchronized std::set.
// NOTE(review): presumably the intent is for a race detector to flag
// overlapping callbacks if two load thread pools were alive at once --
// confirm.
881 TEST_F(SetNumLoadThreadsBasicManagerTest, ThreadPoolsNotAliveSimultaneously) {
882 test_util::BasicManagerTestAccess manager_test_access(basic_manager_.get());
883 manager_test_access.SetNumLoadThreads(1);
884 EXPECT_EQ(1, manager_test_access.num_load_threads());
886 std::set<string> data_race_set;
887 const auto data_race_fn = [&](
const Status& status) {
891 data_race_set.insert(
"string");
894 const ServableId id0 = {kServableName3, 0};
895 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id0)));
// The first load's callback announces itself, then blocks until
// continue_load is notified by the second scheduled task.
896 Notification notify_for_setting;
897 Notification continue_load;
898 basic_manager_->LoadServable(id0, [&](
const Status& status) {
899 notify_for_setting.Notify();
900 continue_load.WaitForNotification();
901 data_race_fn(status);
905 ThreadPoolExecutor executor(Env::Default(),
"SetNumLoadThreads",
// Task 1: once the first load's callback is running, set the load thread
// count again.
907 executor.Schedule([&]() {
908 notify_for_setting.WaitForNotification();
909 manager_test_access.SetNumLoadThreads(1);
910 EXPECT_EQ(1, manager_test_access.num_load_threads());
// Task 2: manage a second servable, release the first callback, and issue
// the second load.
913 executor.Schedule([&]() {
914 const ServableId id1 = {kServableName3, 1};
915 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id1)));
916 continue_load.Notify();
917 basic_manager_->LoadServable(
918 id1, [&](
const Status& status) { data_race_fn(status); });
923 basic_manager_.reset();
// Raises the load thread count to 32, loads versions 0..19 of kServableName3
// from an executor, restores the previous thread count, and then unloads the
// same versions from a second executor with kNumThreads threads.
// NOTE(review): both executors are named `executor`; presumably they live in
// separate scopes whose braces are not visible here.
928 TEST_F(SetNumLoadThreadsBasicManagerTest, FastLoad) {
929 test_util::BasicManagerTestAccess manager_test_access(basic_manager_.get());
930 const uint32 prev_num_load_threads = manager_test_access.num_load_threads();
931 manager_test_access.SetNumLoadThreads(32);
932 EXPECT_EQ(32, manager_test_access.num_load_threads());
935 ThreadPoolExecutor executor(Env::Default(),
"FirstThreadPoolLoads",
937 for (
int i = 0; i < 20; ++i) {
938 executor.Schedule([
this, i]() {
939 const ServableId
id = {kServableName3, i};
940 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
941 basic_manager_->LoadServable(
942 id, [](
const Status& status) { TF_ASSERT_OK(status); });
// Restore the thread count recorded at the start of the test.
951 manager_test_access.SetNumLoadThreads(prev_num_load_threads);
952 EXPECT_EQ(prev_num_load_threads, manager_test_access.num_load_threads());
955 ThreadPoolExecutor executor(Env::Default(),
"Unloads", kNumThreads);
956 for (
int i = 0; i < 20; ++i) {
957 executor.Schedule([
this, i]() {
958 const ServableId
id = {kServableName3, i};
959 basic_manager_->UnloadServable(
960 id, [](
const Status& status) { TF_ASSERT_OK(status); });
// File system derived from NullFileSystem whose FlushCaches() override records
// that it was invoked by setting the static `flushed` flag to true.
969 class FlushDetectingFileSystem :
public NullFileSystem {
971 void FlushCaches()
override { flushed =
true; }
// Shared across all instances; tests reset it with store(false) before use.
972 static std::atomic<bool> flushed;
// Out-of-class definition of the static flag declared in
// FlushDetectingFileSystem.
975 std::atomic<bool> FlushDetectingFileSystem::flushed;
// Registers FlushDetectingFileSystem under the "flush" key.
977 REGISTER_FILE_SYSTEM(
"flush", FlushDetectingFileSystem);
// Bool-parameterized fixture. The test parameter is stored in
// flush_filesystem_caches_ and forwarded to
// BasicManager::Options::flush_filesystem_caches when basic_manager_ is
// created in the constructor.
982 class FlushFileSystemCachesTest :
public ::testing::TestWithParam<bool> {
984 FlushFileSystemCachesTest() : flush_filesystem_caches_(GetParam()) {
985 BasicManager::Options options;
986 options.flush_filesystem_caches = flush_filesystem_caches_;
987 TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager_));
// Manager under test.
990 std::unique_ptr<BasicManager> basic_manager_;
// Copy of the test parameter; the expected flush outcome in some phases.
991 bool flush_filesystem_caches_;
// Loads three versions of kServableName3 under different load-thread settings
// and checks, inside each load callback, whether
// FlushDetectingFileSystem::flushed was set.
994 TEST_P(FlushFileSystemCachesTest, Load) {
995 test_util::BasicManagerTestAccess manager_test_access(basic_manager_.get());
// Phase 1: thread count as created by the fixture (BasicManager::Options
// default; not set in the visible code). The flag must equal the test
// parameter.
998 FlushDetectingFileSystem::flushed.store(
false);
999 const ServableId id0 = {kServableName3, 0};
1000 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id0)));
1001 basic_manager_->LoadServable(id0, [&](
const Status& status) {
1002 TF_ASSERT_OK(status);
1003 EXPECT_EQ(flush_filesystem_caches_,
1004 FlushDetectingFileSystem::flushed.load());
// Phase 2: with two load threads the flag is expected to stay false,
// regardless of the test parameter.
1008 manager_test_access.SetNumLoadThreads(2);
1009 FlushDetectingFileSystem::flushed.store(
false);
1010 const ServableId id1 = {kServableName3, 1};
1011 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id1)));
1012 basic_manager_->LoadServable(id1, [&](
const Status& status) {
1013 TF_ASSERT_OK(status);
1014 EXPECT_FALSE(FlushDetectingFileSystem::flushed.load());
// Phase 3: back to one load thread; the flag must again equal the parameter.
1018 manager_test_access.SetNumLoadThreads(1);
1019 FlushDetectingFileSystem::flushed.store(
false);
1020 const ServableId id2 = {kServableName3, 2};
1021 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(id2)));
1022 basic_manager_->LoadServable(id2, [&](
const Status& status) {
1023 TF_ASSERT_OK(status);
1024 EXPECT_EQ(flush_filesystem_caches_,
1025 FlushDetectingFileSystem::flushed.load());
1027 basic_manager_.reset();
// Instantiates FlushFileSystemCachesTest under the prefix WithOrWithoutFlush.
// The parameter generator argument is not visible in this listing.
// NOTE(review): INSTANTIATE_TEST_CASE_P is the legacy GoogleTest spelling;
// newer releases name it INSTANTIATE_TEST_SUITE_P -- confirm the googletest
// version in use before renaming.
1030 INSTANTIATE_TEST_CASE_P(WithOrWithoutFlush, FlushFileSystemCachesTest,
// Issues four load requests for the same servable id from executor tasks and
// checks that exactly one reports OK while the others fail as duplicates.
1033 TEST_P(BasicManagerTest, ConcurrentLoadsOnlyOneSucceeds) {
1034 const ServableId
id = {kServableName3, 0};
// One slot per request; written from the load callbacks under status_mu.
// NOTE(review): status_mu is not declared in the visible lines -- presumably
// a mutex declared alongside `statuses`; confirm.
1036 std::vector<Status> statuses(4);
1038 ThreadPoolExecutor load_executor(Env::Default(),
"LoadServables",
1040 for (
int i = 0; i < 4; ++i) {
1041 load_executor.Schedule([
this,
id, i, &statuses, &status_mu]() {
// ManageServable's status is deliberately discarded: every task attempts it
// for the same id.
1043 basic_manager_->ManageServable(CreateServable(
id)).IgnoreError();
1044 basic_manager_->LoadServable(
1045 id, [i, &statuses, &status_mu](
const Status& status) {
1046 mutex_lock l(status_mu);
1047 statuses[i] = status;
// Destroy the manager before inspecting the recorded statuses.
1055 basic_manager_.reset();
// Every non-OK status must be FAILED_PRECONDITION mentioning a duplicate load.
// NOTE(review): the increment of num_status_ok is not visible here.
1057 int num_status_ok = 0;
1058 for (
int i = 0; i < 4; ++i) {
1059 mutex_lock l(status_mu);
1060 if (!statuses[i].ok()) {
1061 EXPECT_EQ(error::FAILED_PRECONDITION, statuses[i].code());
1062 EXPECT_THAT(statuses[i].message(), HasSubstr(
"Duplicate load request"));
1067 EXPECT_EQ(1, num_status_ok);
// Loads one servable, waits until it is available, then issues four unload
// requests for it from executor tasks and checks that exactly one reports OK.
1070 TEST_P(BasicManagerTest, ConcurrentUnloadsOnlyOneSucceeds) {
1071 const ServableId
id = {kServableName3, 0};
1072 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServable(
id)));
1073 basic_manager_->LoadServable(
1074 id, [](
const Status& status) { TF_ASSERT_OK(status); });
// Block until the state monitor reports the servable as kAvailable.
1076 WaitUntilServableManagerStateIsOneOf(
1077 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
// One slot per request; written from the unload callbacks under status_mu.
// NOTE(review): status_mu is not declared in the visible lines -- confirm.
1080 std::vector<Status> statuses(4);
1082 ThreadPoolExecutor load_executor(Env::Default(),
"LoadServables",
1084 for (
int i = 0; i < 4; ++i) {
1085 load_executor.Schedule([
this,
id, i, &statuses, &status_mu]() {
1086 basic_manager_->UnloadServable(
1087 id, [i, &statuses, &status_mu](
const Status& status) {
1088 mutex_lock l(status_mu);
1089 statuses[i] = status;
// Destroy the manager before inspecting the recorded statuses.
1097 basic_manager_.reset();
// A failed unload may report either NOT_FOUND ("not being managed") or
// FAILED_PRECONDITION ("unload already requested/ongoing").
// NOTE(review): the increment of num_status_ok is not visible here.
1099 int num_status_ok = 0;
1100 for (
int i = 0; i < 4; ++i) {
1101 mutex_lock l(status_mu);
1103 if (!statuses[i].ok()) {
1104 ASSERT_THAT(statuses[i].code(),
1105 AnyOf(error::NOT_FOUND, error::FAILED_PRECONDITION));
1106 if (statuses[i].code() == error::NOT_FOUND) {
1107 EXPECT_THAT(statuses[i].message(), HasSubstr(
"not being managed"));
1109 EXPECT_THAT(statuses[i].message(),
1110 HasSubstr(
"unload already requested/ongoing"));
1116 EXPECT_EQ(1, num_status_ok);
// The mock loader fails its first LoadWithMetadata call with an Internal error
// and succeeds on every later call; the load callback must still see OK.
// NOTE(review): the retry budget comes from the fixture's options, which are
// outside this view.
1119 TEST_P(BasicManagerTest, RetryOnLoadErrorFinallySucceeds) {
1120 const ServableId
id = {kServableName, 7};
// Raw pointer kept for setting expectations; ownership is handed to the
// manager through the unique_ptr below.
1121 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>();
1123 basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
1124 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{
id}))
1125 .WillOnce(Return(errors::Internal(
"Load error.")))
1126 .WillRepeatedly(Return(OkStatus()));
1127 basic_manager_->LoadServable(
1128 id, [](
const Status& status) { TF_ASSERT_OK(status); });
// The mock loader fails every LoadWithMetadata call with the same Internal
// error; the load callback must receive exactly that error.
1131 TEST_P(BasicManagerTest, RetryOnLoadErrorFinallyFails) {
1132 const ServableId
id = {kServableName, 7};
// Raw pointer kept for setting expectations; ownership is handed to the
// manager through the unique_ptr below.
1133 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>();
1135 basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
1136 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{
id}))
1137 .WillRepeatedly(Return(errors::Internal(
"Load error.")));
1138 basic_manager_->LoadServable(
id, [](
const Status& status) {
1139 EXPECT_EQ(errors::Internal(
"Load error."), status);
// Cancels load retries while the first load attempt is still in progress and
// checks that the load callback sees the first attempt's Internal error and
// the servable reaches kEnd.
1144 TEST_P(BasicManagerTest, RetryOnLoadErrorCancelledLoad) {
1145 const ServableId
id = {kServableName, 7};
1146 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>();
1148 basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
// load_called: set by the loader when the first attempt begins.
// load_should_return: set by the test to let that attempt finish.
1150 Notification load_called;
1151 Notification load_should_return;
// First attempt blocks until released and then fails; any further attempt
// would succeed.
1152 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{
id}))
1153 .WillOnce(InvokeWithoutArgs([&load_called, &load_should_return]() {
1154 load_called.Notify();
1155 load_should_return.WaitForNotification();
1156 return errors::Internal(
"Load error.");
1158 .WillRepeatedly(Return(OkStatus()));
// LoadServable is issued from its own thread so this thread can cancel the
// retry while the loader is blocked.
1159 std::unique_ptr<Thread> load_thread(
1160 Env::Default()->StartThread(ThreadOptions(),
"LoadServable", [&]() {
1161 basic_manager_->LoadServable(
id, [](
const Status& status) {
1162 EXPECT_EQ(errors::Internal(
"Load error."), status);
// Cancel only after the first attempt has started, then let it return.
1165 load_called.WaitForNotification();
1166 basic_manager_->CancelLoadServableRetry(
id);
1167 load_should_return.Notify();
1168 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
1169 {ServableState::ManagerState::kEnd});
// Same cancelled-load scenario as RetryOnLoadErrorCancelledLoad, followed by a
// second LoadServable on the same id once it has reached kEnd; that second
// request is expected to report a non-OK status.
1172 TEST_P(BasicManagerTest, LoadAfterCancelledLoad) {
1173 const ServableId
id = {kServableName, 7};
1174 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>();
1176 basic_manager_->ManageServable({id, std::unique_ptr<Loader>(loader)}));
// load_called: set by the loader when the first attempt begins.
// load_should_return: set by the test to let that attempt finish.
1178 Notification load_called;
1179 Notification load_should_return;
1180 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{
id}))
1181 .WillOnce(InvokeWithoutArgs([&load_called, &load_should_return]() {
1182 load_called.Notify();
1183 load_should_return.WaitForNotification();
1184 return errors::Internal(
"Load error.");
1186 .WillRepeatedly(Return(OkStatus()));
// Issue the load from a separate thread so this thread can cancel the retry.
1188 std::unique_ptr<Thread> load_thread(
1189 Env::Default()->StartThread(ThreadOptions(),
"LoadServable", [&]() {
1190 basic_manager_->LoadServable(
id, [](
const Status& status) {
1191 EXPECT_EQ(errors::Internal(
"Load error."), status);
1194 load_called.WaitForNotification();
1195 basic_manager_->CancelLoadServableRetry(
id);
1196 load_should_return.Notify();
1197 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
1198 {ServableState::ManagerState::kEnd});
// Second load request after the servable has ended: must not succeed. The
// status is streamed into the failure message for diagnosis.
1200 basic_manager_->LoadServable(
1201 id, [](
const Status& status) { EXPECT_FALSE(status.ok()) << status; });
// Verifies that the configured pre_load_hook is called with the servable id
// before the loader's LoadWithMetadata runs.
1204 TEST(NonParameterizedBasicManagerTest, PreLoadHook) {
1205 BasicManager::Options options;
// NOTE(review): zero load threads presumably makes loads run inline on the
// calling thread, so the plain bool below needs no synchronization -- confirm.
1207 options.num_load_threads = 0;
1209 options.servable_event_bus =
nullptr;
1210 MockFunction<void(
const ServableId&)> mock_pre_load_hook;
1211 options.pre_load_hook = mock_pre_load_hook.AsStdFunction();
1212 std::unique_ptr<BasicManager> manager;
1213 TF_ASSERT_OK(BasicManager::Create(std::move(options), &manager));
1215 const ServableId
id = {kServableName, 7};
1216 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>();
1217 TF_ASSERT_OK(manager->ManageServable({id, std::unique_ptr<Loader>(loader)}));
// The hook records that it ran; the loader then asserts the flag is set.
1219 bool pre_load_hook_called =
false;
1220 EXPECT_CALL(mock_pre_load_hook, Call(
id)).WillOnce(InvokeWithoutArgs([&]() {
1221 pre_load_hook_called =
true;
1223 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{
id}))
1224 .WillOnce(InvokeWithoutArgs([&]() {
1225 EXPECT_TRUE(pre_load_hook_called);
// Load and then unload the servable; both callbacks must report OK.
1228 manager->LoadServable(
id, [](
const Status& status) { TF_ASSERT_OK(status); });
1229 manager->UnloadServable(
id,
1230 [](
const Status& status) { TF_ASSERT_OK(status); });
// Builds a ResourceAllocation holding a single entry: device "main", kind
// "ram", with the given quantity.
1234 ResourceAllocation CreateResourceQuantity(
const int quantity) {
1235 ResourceAllocation allocation;
1236 auto* ram_resource = allocation.add_resource_quantities();
1237 ram_resource->mutable_resource()->set_device(
"main");
1238 ram_resource->mutable_resource()->set_kind(
"ram");
1239 ram_resource->set_quantity(quantity);
// Creates a ResourceTracker whose total resources are
// CreateResourceQuantity(resource_quantity). The ResourceUtil is constructed
// with a single {"main", 1} entry. Creation failure is caught by TF_CHECK_OK.
1245 std::unique_ptr<ResourceTracker> CreateSimpleResourceTracker(
1246 const int resource_quantity) {
1247 std::unique_ptr<ResourceUtil> util(
new ResourceUtil({{{
"main", 1}}}));
1248 std::unique_ptr<ResourceTracker> tracker;
1249 TF_CHECK_OK(ResourceTracker::Create(CreateResourceQuantity(resource_quantity),
1250 std::move(util), &tracker));
// Fixture that builds a BasicManager with a resource tracker of 10 units, an
// event bus observed by servable_state_monitor_, two load threads, two unload
// threads, and no load retries.
1254 class ResourceConstrainedBasicManagerTest :
public ::testing::Test {
1256 ResourceConstrainedBasicManagerTest()
1257 : servable_event_bus_(EventBus<ServableState>::CreateEventBus()),
1258 servable_state_monitor_(servable_event_bus_.get()) {
1259 BasicManager::Options options;
// Total resources available to the manager: 10 units of "main"/"ram".
1261 options.resource_tracker = CreateSimpleResourceTracker(10);
1262 options.servable_event_bus = servable_event_bus_.get();
1264 options.num_load_threads = 2;
1265 options.num_unload_threads = 2;
// A failed load is reported immediately instead of being retried.
1267 options.max_num_load_retries = 0;
1268 TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager_));
// Declared in this order so the monitor can be constructed from the bus in
// the constructor's initializer list.
1271 std::shared_ptr<EventBus<ServableState>> servable_event_bus_;
1272 ServableStateMonitor servable_state_monitor_;
1273 std::unique_ptr<BasicManager> basic_manager_;
// Loader that holds a non-owned BlockingCounter. It estimates 5 resource
// units, decrements the counter in Load(), does nothing in Unload(), and
// exposes an empty AnyPtr as its servable.
1278 class BarrierLoader :
public Loader {
1280 explicit BarrierLoader(BlockingCounter* counter) : counter_(counter) {}
1281 ~BarrierLoader()
override =
default;
// Always reports a requirement of 5 units.
1283 Status EstimateResources(ResourceAllocation* estimate)
const override {
1284 *estimate = CreateResourceQuantity(5);
// Signals arrival at the barrier by decrementing the shared counter.
1288 Status Load()
override {
1289 counter_->DecrementCount();
1294 void Unload()
override {}
1296 AnyPtr servable()
override {
return AnyPtr(); }
// Not owned; must outlive this loader.
1299 BlockingCounter*
const counter_;
1301 TF_DISALLOW_COPY_AND_ASSIGN(BarrierLoader);
// Manages and loads two BarrierLoaders that share one BlockingCounter. Each
// loader estimates 5 units, so together they fill the fixture's 10-unit
// tracker. Both loads are expected to succeed.
1304 TEST_F(ResourceConstrainedBasicManagerTest, ConcurrentLoads) {
// NOTE(review): named like a constant but declared as a mutable int --
// consider `const int` if nothing modifies it.
1308 int kNumLoaders = 2;
1309 BlockingCounter barrier(kNumLoaders);
1310 for (
int i = 0; i < kNumLoaders; ++i) {
1311 std::unique_ptr<Loader> loader(
new BarrierLoader(&barrier));
1312 const ServableId
id = {
"barrier", i};
1313 TF_ASSERT_OK(basic_manager_->ManageServable(
1314 CreateServableData(
id, std::move(loader))));
1315 basic_manager_->LoadServable(
1316 id, [](
const Status& status) { TF_EXPECT_OK(status); });
// Destroy the manager while `barrier` is still in scope.
1319 basic_manager_.reset();
// Loads a servable that consumes all 10 units of the tracker, then checks that
// a second servable needing 1 unit is rejected with RESOURCE_EXHAUSTED and
// ends up in state kEnd carrying that status.
1322 TEST_F(ResourceConstrainedBasicManagerTest, InsufficientResources) {
// "hogging" servable: estimates 10 units and loads successfully.
1325 const ServableId hogging_id = {
"hogging", 0};
1326 test_util::MockLoader* hogging_loader =
new NiceMock<test_util::MockLoader>;
1327 ON_CALL(*hogging_loader, EstimateResources(_))
1328 .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1329 *estimate = CreateResourceQuantity(10 );
1332 EXPECT_CALL(*hogging_loader, LoadWithMetadata(Loader::Metadata{hogging_id}))
1333 .WillOnce(Return(OkStatus()));
1334 TF_ASSERT_OK(basic_manager_->ManageServable(
1335 CreateServableData(hogging_id, std::unique_ptr<Loader>(hogging_loader))));
1336 Notification hogging_loaded;
1337 basic_manager_->LoadServable(hogging_id,
1338 [&hogging_loaded](
const Status& status) {
1339 TF_EXPECT_OK(status);
1340 hogging_loaded.Notify();
// Make sure the first load has completed before attempting the second.
1342 hogging_loaded.WaitForNotification();
// "rejected" servable: estimates 1 unit, which no longer fits.
1345 const ServableId rejected_id = {
"rejected", 0};
1346 test_util::MockLoader* rejected_loader =
new NiceMock<test_util::MockLoader>;
1347 ON_CALL(*rejected_loader, EstimateResources(_))
1348 .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1349 *estimate = CreateResourceQuantity(1);
1352 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1353 rejected_id, std::unique_ptr<Loader>(rejected_loader))));
1354 Notification rejection_received;
1355 Status rejected_status;
// The callback captures the failure status so it can be compared with the
// state published to the monitor.
1356 basic_manager_->LoadServable(
1358 [&rejection_received, &rejected_status](
const Status& status) {
1359 ASSERT_FALSE(status.ok());
1360 ASSERT_EQ(error::RESOURCE_EXHAUSTED, status.code());
1361 rejected_status = status;
1362 rejection_received.Notify();
1364 rejection_received.WaitForNotification();
1365 const ServableState expected_error_state = {
1366 rejected_id, ServableState::ManagerState::kEnd, rejected_status};
1367 EXPECT_THAT(*servable_state_monitor_.
GetState(rejected_id),
1368 EqualsServableState(expected_error_state));
// Fetches the manager's snapshot for the rejected id; the check performed on
// it is not visible in this listing.
1371 const absl::optional<ServableStateSnapshot<>> snapshot =
1372 basic_manager_->GetManagedServableStateSnapshot(rejected_id);
// A servable that estimates all 10 units fails to load; a second servable that
// also needs 10 units must then load successfully, which requires the first
// reservation to have been released.
1376 TEST_F(ResourceConstrainedBasicManagerTest, ResourcesReleasedIfLoadFails) {
// "failing" servable: estimates 10 units, LoadWithMetadata returns Unknown.
1378 const ServableId failing_id = {
"failing", 0};
1379 test_util::MockLoader* failing_loader =
new NiceMock<test_util::MockLoader>;
1380 ON_CALL(*failing_loader, EstimateResources(_))
1381 .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1382 *estimate = CreateResourceQuantity(10);
1385 EXPECT_CALL(*failing_loader, LoadWithMetadata(Loader::Metadata{failing_id}))
1386 .WillOnce(Return(errors::Unknown(
"Load failure")));
1387 TF_ASSERT_OK(basic_manager_->ManageServable(
1388 CreateServableData(failing_id, std::unique_ptr<Loader>(failing_loader))));
1389 Notification failing_failed;
1390 basic_manager_->LoadServable(failing_id,
1391 [&failing_failed](
const Status& status) {
1392 EXPECT_FALSE(status.ok());
1393 failing_failed.Notify();
// Wait for the failure to be reported before starting the second load.
1395 failing_failed.WaitForNotification();
// "succeeding" servable: also estimates 10 units and must load OK.
1400 const ServableId succeeding_id = {
"succeeding", 0};
1401 test_util::MockLoader* succeeding_loader =
1402 new NiceMock<test_util::MockLoader>;
1403 ON_CALL(*succeeding_loader, EstimateResources(_))
1404 .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1405 *estimate = CreateResourceQuantity(10);
1408 EXPECT_CALL(*succeeding_loader,
1409 LoadWithMetadata(Loader::Metadata{succeeding_id}))
1410 .WillOnce(Return(OkStatus()));
1411 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1412 succeeding_id, std::unique_ptr<Loader>(succeeding_loader))));
1413 basic_manager_->LoadServable(
1414 succeeding_id, [](
const Status& status) { TF_EXPECT_OK(status); });
// A loader estimates 10 units before loading and 5 units on its next estimate
// call; a second servable needing 5 units must then load successfully.
1417 TEST_F(ResourceConstrainedBasicManagerTest,
1418 ResourcesReleasedIfEstimateDecreasesAfterLoad) {
1420 const ServableId overestimating_id = {
"overestimating", 0};
1421 test_util::MockLoader* overestimating_loader =
1422 new NiceMock<test_util::MockLoader>;
// The expectations below must be satisfied in this order: estimate (10),
// load, estimate (5). RetiresOnSaturation() retires each estimate
// expectation once it has been matched.
1424 InSequence sequence;
1425 EXPECT_CALL(*overestimating_loader, EstimateResources(_))
1426 .WillOnce(Invoke([](ResourceAllocation* estimate) {
1427 *estimate = CreateResourceQuantity(10);
1430 .RetiresOnSaturation();
1431 EXPECT_CALL(*overestimating_loader,
1432 LoadWithMetadata(Loader::Metadata{overestimating_id}))
1433 .WillOnce(Return(OkStatus()));
1434 EXPECT_CALL(*overestimating_loader, EstimateResources(_))
1435 .WillOnce(Invoke([](ResourceAllocation* estimate) {
1436 *estimate = CreateResourceQuantity(5 );
1439 .RetiresOnSaturation();
1441 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1442 overestimating_id, std::unique_ptr<Loader>(overestimating_loader))));
1443 Notification overestimating_loaded;
1444 basic_manager_->LoadServable(overestimating_id,
1445 [&overestimating_loaded](
const Status& status) {
1446 TF_EXPECT_OK(status);
1447 overestimating_loaded.Notify();
// Wait for the first load to finish before starting the second.
1449 overestimating_loaded.WaitForNotification();
// "succeeding" servable: estimates 5 units and must load OK.
1454 const ServableId succeeding_id = {
"succeeding", 0};
1455 test_util::MockLoader* succeeding_loader =
1456 new NiceMock<test_util::MockLoader>;
1457 ON_CALL(*succeeding_loader, EstimateResources(_))
1458 .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1459 *estimate = CreateResourceQuantity(5);
1462 EXPECT_CALL(*succeeding_loader,
1463 LoadWithMetadata(Loader::Metadata{succeeding_id}))
1464 .WillOnce(Return(OkStatus()));
1465 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1466 succeeding_id, std::unique_ptr<Loader>(succeeding_loader))));
1467 basic_manager_->LoadServable(
1468 succeeding_id, [](
const Status& status) { TF_EXPECT_OK(status); });
// Loads a servable that uses all 10 units, starts unloading it, and while the
// unload is held open requests a load of a second 10-unit servable. The second
// load is expected to succeed.
1471 TEST_F(ResourceConstrainedBasicManagerTest, ResourcesReleasedAfterUnload) {
1472 const ServableId unloading_id = {
"unloading", 0};
1473 test_util::MockLoader* unloading_loader =
new NiceMock<test_util::MockLoader>;
1474 ON_CALL(*unloading_loader, EstimateResources(_))
1475 .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1476 *estimate = CreateResourceQuantity(10);
1479 Notification load_done;
1480 EXPECT_CALL(*unloading_loader,
1481 LoadWithMetadata(Loader::Metadata{unloading_id}))
1482 .WillOnce(Return(OkStatus()));
1483 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1484 unloading_id, std::unique_ptr<Loader>(unloading_loader))));
// NOTE(review): no load_done.Notify() is visible in this callback, yet the
// test waits on load_done right after -- confirm the callback notifies it.
1485 basic_manager_->LoadServable(unloading_id,
1486 [&load_done](
const Status& status) {
1487 TF_EXPECT_OK(status);
1490 load_done.WaitForNotification();
// Unload() announces that it has started and then blocks until the second
// loader's first estimate call releases it.
1491 Notification unload_started;
1492 Notification finish_unload;
1493 EXPECT_CALL(*unloading_loader, Unload())
1494 .WillOnce(Invoke([&unload_started, &finish_unload] {
1495 unload_started.Notify();
1496 finish_unload.WaitForNotification();
1498 basic_manager_->UnloadServable(
1499 unloading_id, [](
const Status& status) { TF_EXPECT_OK(status); });
1500 unload_started.WaitForNotification();
// "succeeding" servable: its first EstimateResources call lets the pending
// unload finish; both estimate calls report 10 units.
1505 const ServableId succeeding_id = {
"succeeding", 0};
1506 test_util::MockLoader* succeeding_loader =
1507 new NiceMock<test_util::MockLoader>;
1508 EXPECT_CALL(*succeeding_loader, EstimateResources(_))
1509 .WillOnce(Invoke([&finish_unload](ResourceAllocation* estimate) {
1510 finish_unload.Notify();
1511 *estimate = CreateResourceQuantity(10);
1514 .WillOnce(Invoke([](ResourceAllocation* estimate) {
1515 *estimate = CreateResourceQuantity(10);
1518 EXPECT_CALL(*succeeding_loader,
1519 LoadWithMetadata(Loader::Metadata{succeeding_id}))
1520 .WillOnce(Return(OkStatus()));
1521 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1522 succeeding_id, std::unique_ptr<Loader>(succeeding_loader))));
1523 basic_manager_->LoadServable(
1524 succeeding_id, [](
const Status& status) { TF_EXPECT_OK(status); });
// Destroy the manager while the notifications above are still in scope.
1527 basic_manager_.reset();
// Starts a load whose resource estimate (11 units) exceeds the tracker's 10,
// holds that estimate open, and meanwhile requests a second load needing 10
// units. The first load must be denied with RESOURCE_EXHAUSTED and the second
// must succeed.
1530 TEST_F(ResourceConstrainedBasicManagerTest, FirstLoadDeniedSecondOneApproved) {
1532 const ServableId denied_id = {
"denied", 0};
1533 test_util::MockLoader* denied_loader =
new NiceMock<test_util::MockLoader>;
// denied_estimate_started: set when the denied loader's estimate begins.
// finish_denied_estimate: set by the test to let that estimate return.
1534 Notification denied_estimate_started;
1535 Notification finish_denied_estimate;
1536 EXPECT_CALL(*denied_loader, EstimateResources(_))
1537 .WillOnce(Invoke([&denied_estimate_started,
1538 &finish_denied_estimate](ResourceAllocation* estimate) {
1539 denied_estimate_started.Notify();
1540 finish_denied_estimate.WaitForNotification();
1541 *estimate = CreateResourceQuantity(11 );
// NOTE(review): the cardinality clause of this expectation is not visible
// here; presumably the denied loader must never be asked to load -- confirm.
1545 EXPECT_CALL(*denied_loader, LoadWithMetadata(Loader::Metadata{denied_id}))
1547 TF_ASSERT_OK(basic_manager_->ManageServable(
1548 CreateServableData(denied_id, std::unique_ptr<Loader>(denied_loader))));
// "succeeding" servable: estimates exactly the tracker's 10 units.
1551 const ServableId succeeding_id = {
"succeeding", 0};
1552 test_util::MockLoader* succeeding_loader =
1553 new NiceMock<test_util::MockLoader>;
1554 ON_CALL(*succeeding_loader, EstimateResources(_))
1555 .WillByDefault(Invoke([](ResourceAllocation* estimate) {
1556 *estimate = CreateResourceQuantity(10);
1559 TF_ASSERT_OK(basic_manager_->ManageServable(CreateServableData(
1560 succeeding_id, std::unique_ptr<Loader>(succeeding_loader))));
// Captured by the denied load's callback for the final state comparison.
1562 Status denied_load_status;
1564 basic_manager_->LoadServable(
1565 denied_id, [&denied_load_status](
const Status& status) {
1566 denied_load_status = status;
1567 ASSERT_FALSE(status.ok());
1568 EXPECT_EQ(error::RESOURCE_EXHAUSTED, status.code());
// Proceed only once the denied loader is inside its (blocked) estimate.
1570 denied_estimate_started.WaitForNotification();
// The second servable's load must only run after the denied estimate has
// been released.
1573 EXPECT_CALL(*succeeding_loader,
1574 LoadWithMetadata(Loader::Metadata{succeeding_id}))
1575 .WillOnce(InvokeWithoutArgs([&finish_denied_estimate]() {
1578 EXPECT_TRUE(finish_denied_estimate.HasBeenNotified());
// Request the second load from its own thread, then release the denied
// estimate from this thread.
1588 std::unique_ptr<Thread> load_servable(
1589 Env::Default()->StartThread({},
"LoadServable", [&]() {
1590 basic_manager_->LoadServable(succeeding_id, [](
const Status& status) {
1591 TF_EXPECT_OK(status);
1595 finish_denied_estimate.Notify();
// Destroy the manager before checking the published state.
1599 basic_manager_.reset();
1601 const ServableState expected_error_state = {
1602 denied_id, ServableState::ManagerState::kEnd, denied_load_status};
1603 EXPECT_THAT(*servable_state_monitor_.
GetState(denied_id),
1604 EqualsServableState(expected_error_state));
// When EstimateResources fails, the load callback must see a non-OK status and
// the state published to the monitor must be kEnd with an Internal error whose
// message wraps the loader's error text.
1607 TEST_F(ResourceConstrainedBasicManagerTest, EventBusErrorOnEstimateResources) {
1608 const ServableId
id = {kServableName, 7};
1609 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
1610 EXPECT_CALL(*loader, EstimateResources(_))
1611 .WillOnce(Return(errors::Internal(
"Error on estimate resources.")));
1612 TF_ASSERT_OK(basic_manager_->ManageServable(
1613 CreateServableData(
id, std::unique_ptr<Loader>(loader))));
1614 basic_manager_->LoadServable(
1615 id, [](
const Status& status) { EXPECT_FALSE(status.ok()); });
1616 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
1617 {ServableState::ManagerState::kEnd});
// Expected message: fixed prefix + servable id + ": " + the loader's message.
1618 const ServableState error_state = {
1619 id, ServableState::ManagerState::kEnd,
1620 errors::Internal(strings::StrCat(
1621 "Error while attempting to reserve resources to load servable ",
1622 id.DebugString(),
": Error on estimate resources."))};
1623 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
1624 EqualsServableState(error_state));
// With one retry allowed, a loader whose first EstimateResources call fails
// and whose second succeeds must end up kAvailable with an OK status.
1627 TEST(EstimateResourcesRetriedTest, Succeeds) {
1628 std::shared_ptr<EventBus<ServableState>> servable_event_bus =
1630 ServableStateMonitor servable_state_monitor(servable_event_bus.get());
1632 BasicManager::Options options;
1634 options.resource_tracker = CreateSimpleResourceTracker(10);
1635 options.servable_event_bus = servable_event_bus.get();
1636 options.num_load_threads = 0;
1637 options.num_unload_threads = 0;
// One retry, with no delay between attempts.
1639 options.max_num_load_retries = 1;
1640 options.load_retry_interval_micros = 0;
1642 std::unique_ptr<BasicManager> basic_manager;
1643 TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager));
1645 const ServableId
id = {kServableName, 7};
1646 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
// First estimate fails, second succeeds; loading itself always succeeds.
1647 EXPECT_CALL(*loader, EstimateResources(_))
1648 .WillOnce(Return(errors::Internal(
"Error on estimate resources.")))
1649 .WillOnce(Return(OkStatus()));
1650 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{
id}))
1651 .WillRepeatedly(Return(OkStatus()));
1652 TF_ASSERT_OK(basic_manager->ManageServable(
1653 CreateServableData(
id, std::unique_ptr<Loader>(loader))));
1654 basic_manager->LoadServable(
1655 id, [](
const Status& status) { EXPECT_TRUE(status.ok()); });
1656 WaitUntilServableManagerStateIsOneOf(
1657 servable_state_monitor,
id, {ServableState::ManagerState::kAvailable});
1658 const ServableState available_state = {
1659 id, ServableState::ManagerState::kAvailable, OkStatus()};
1660 EXPECT_THAT(*servable_state_monitor.GetState(
id),
1661 EqualsServableState(available_state));
// With one retry allowed, a loader whose first two EstimateResources calls
// both fail must end up in kEnd with a non-OK health status.
1664 TEST(EstimateResourcesRetriedTest, Fails) {
1665 std::shared_ptr<EventBus<ServableState>> servable_event_bus =
1667 ServableStateMonitor servable_state_monitor(servable_event_bus.get());
1669 BasicManager::Options options;
1671 options.resource_tracker = CreateSimpleResourceTracker(10);
1672 options.servable_event_bus = servable_event_bus.get();
1673 options.num_load_threads = 0;
1674 options.num_unload_threads = 0;
// One retry, with no delay between attempts.
1676 options.max_num_load_retries = 1;
1677 options.load_retry_interval_micros = 0;
1679 std::unique_ptr<BasicManager> basic_manager;
1680 TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager));
1682 const ServableId
id = {kServableName, 7};
1683 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
// Two failures exhaust the initial attempt plus the single retry; any call
// beyond that would succeed.
1684 EXPECT_CALL(*loader, EstimateResources(_))
1685 .WillOnce(Return(errors::Internal(
"Error on estimate resources.")))
1686 .WillOnce(Return(errors::Internal(
"Error on estimate resources.")))
1687 .WillRepeatedly(Return(OkStatus()));
1688 TF_ASSERT_OK(basic_manager->ManageServable(
1689 CreateServableData(
id, std::unique_ptr<Loader>(loader))));
1690 basic_manager->LoadServable(
1691 id, [](
const Status& status) { EXPECT_FALSE(status.ok()); });
1692 WaitUntilServableManagerStateIsOneOf(servable_state_monitor,
id,
1693 {ServableState::ManagerState::kEnd});
1694 EXPECT_FALSE(servable_state_monitor.GetState(
id)->health.ok());
// Configures should_retry_model_load to refuse retries for InvalidArgument
// errors, makes the first LoadWithMetadata call fail with InvalidArgument, and
// checks that the servable ends in kEnd with that error instead of being
// retried into kAvailable.
1697 TEST(EstimateResourcesRetriedTest, NonRetriableError) {
1698 std::shared_ptr<EventBus<ServableState>> servable_event_bus =
1700 ServableStateMonitor servable_state_monitor(servable_event_bus.get());
1702 BasicManager::Options options;
1704 options.resource_tracker = CreateSimpleResourceTracker(10);
1705 options.servable_event_bus = servable_event_bus.get();
1706 options.num_load_threads = 0;
1707 options.num_unload_threads = 0;
// Retry predicate: every error is retriable except InvalidArgument.
1708 options.should_retry_model_load =
1709 ([](absl::Status status) {
return !absl::IsInvalidArgument(status); });
// Up to 10 retries with a 100000000-microsecond interval between attempts.
1711 options.max_num_load_retries = 10;
1712 options.load_retry_interval_micros = 100000000;
1714 std::unique_ptr<BasicManager> basic_manager;
1715 TF_CHECK_OK(BasicManager::Create(std::move(options), &basic_manager));
1717 const ServableId
id = {kServableName, 7};
1718 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
// The first load fails with the non-retriable error; a retry, if one were
// made, would succeed.
1719 EXPECT_CALL(*loader, LoadWithMetadata(_))
1720 .WillOnce(Return(errors::InvalidArgument(
"Non-retriable error.")))
1721 .WillRepeatedly(Return(absl::OkStatus()));
1722 TF_ASSERT_OK(basic_manager->ManageServable(
1723 CreateServableData(
id, std::unique_ptr<Loader>(loader))));
1724 basic_manager->LoadServable(
1725 id, [](
const auto& status) { EXPECT_FALSE(status.ok()); });
// Accept either terminal outcome here so the wait returns on both; the
// assertions below then require kEnd.
1728 WaitUntilServableManagerStateIsOneOf(
1729 servable_state_monitor,
id,
1730 {ServableState::ManagerState::kEnd,
1731 ServableState::ManagerState::kAvailable});
1732 const auto final_state = servable_state_monitor.GetState(
id);
1733 ASSERT_TRUE(final_state.has_value());
1734 EXPECT_EQ(final_state->manager_state, ServableState::ManagerState::kEnd);
1735 EXPECT_FALSE(final_state->health.ok());
1736 EXPECT_EQ(final_state->health.message(),
"Non-retriable error.");
static std::shared_ptr< EventBus > CreateEventBus(const Options &options={})
@ kReady
'loader_->Load()' has succeeded.
@ kDisabled
'loader_->Unload()' has finished.
absl::optional< ServableState > GetState(const ServableId &servable_id) const TF_LOCKS_EXCLUDED(mu_)