16 #include "tensorflow_serving/core/aspired_versions_manager.h"
27 #include <gmock/gmock.h>
28 #include <gtest/gtest.h>
29 #include "absl/types/optional.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/notification.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/protobuf/error_codes.pb.h"
35 #include "tensorflow_serving/core/aspired_version_policy.h"
36 #include "tensorflow_serving/core/availability_preserving_policy.h"
37 #include "tensorflow_serving/core/servable_state_monitor.h"
38 #include "tensorflow_serving/core/test_util/availability_test_util.h"
39 #include "tensorflow_serving/core/test_util/fake_loader.h"
40 #include "tensorflow_serving/core/test_util/manager_test_util.h"
41 #include "tensorflow_serving/core/test_util/mock_loader.h"
42 #include "tensorflow_serving/util/any_ptr.h"
43 #include "tensorflow_serving/util/event_bus.h"
45 namespace tensorflow {
49 using test_util::FakeLoader;
50 using test_util::WaitUntilServableManagerStateIsOneOf;
52 using ::testing::Invoke;
53 using ::testing::InvokeWithoutArgs;
54 using ::testing::NiceMock;
55 using ::testing::Return;
56 using ::testing::UnorderedElementsAre;
57 using ::testing::UnorderedElementsAreArray;
// Names of the two servable streams that SetUp() aspires and loads.
constexpr char kServableName[] = "kServableName";
constexpr char kServableName2[] = "kServableName2";
// SetUp() aspires versions [0, kNumVersionsPerServable) of each stream.
constexpr int kNumVersionsPerServable = 2;
// Total versions SetUp() loads: kNumVersionsPerServable for each of the two
// streams above.
constexpr int kNumTotalVersions = 2 * kNumVersionsPerServable;
66 ServableData<std::unique_ptr<Loader>> CreateAspiredVersion(
67 const ServableId&
id) {
68 std::unique_ptr<Loader> loader(
new FakeLoader(
id.version));
69 return CreateServableData(
id, std::move(loader));
// Load/unload thread-pool sizes used to parameterize the test fixture.
// NOTE(review): a size of 0 presumably selects an in-line executor instead of
// a thread pool (DestructOnNonServingThread special-cases
// num_unload_threads == 0) — confirm against AspiredVersionsManager::Options.
struct ThreadPoolSizes {
  uint64_t num_load_threads;    // Copied into Options::num_load_threads.
  uint64_t num_unload_threads;  // Copied into Options::num_unload_threads.
};
78 class AspiredVersionsManagerTest
79 :
public ::testing::TestWithParam<std::tuple<ThreadPoolSizes, bool>> {
81 AspiredVersionsManagerTest()
82 : servable_event_bus_(EventBus<ServableState>::CreateEventBus()),
83 servable_state_monitor_(servable_event_bus_.get()),
84 thread_pool_sizes_(std::get<0>(GetParam())),
85 enable_reload_servables_with_error_(std::get<1>(GetParam())) {
86 AspiredVersionsManager::Options manager_options;
87 manager_options.num_load_threads = thread_pool_sizes_.num_load_threads;
88 manager_options.num_unload_threads = thread_pool_sizes_.num_unload_threads;
90 manager_options.manage_state_interval_micros = -1;
91 manager_options.env = Env::Default();
92 manager_options.aspired_version_policy.reset(
93 new AvailabilityPreservingPolicy());
94 manager_options.servable_event_bus = servable_event_bus_.get();
95 max_num_load_retries_ = 1;
96 manager_options.max_num_load_retries = max_num_load_retries_;
97 manager_options.load_retry_interval_micros = 0;
98 manager_options.enable_reload_servables_with_error =
99 enable_reload_servables_with_error_;
101 AspiredVersionsManager::Create(std::move(manager_options), &manager_));
105 ServableData<std::unique_ptr<Loader>> CreateErroneousAspiredVersion(
106 const ServableId&
id) {
107 return ServableData<std::unique_ptr<Loader>>(id, errors::Unknown(
"error"));
110 void SetUp()
override {
113 std::set<ServableId> servables;
114 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
115 for (
int i = 0; i < kNumVersionsPerServable; ++i) {
116 const ServableId
id = {kServableName, i};
117 aspired_versions.push_back(CreateAspiredVersion(
id));
118 servables.insert(
id);
120 manager_->GetAspiredVersionsCallback()(kServableName,
121 std::move(aspired_versions));
122 HandlePendingAspiredVersionsRequests();
124 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions2;
125 for (
int i = 0; i < kNumVersionsPerServable; ++i) {
126 const ServableId
id = {kServableName2, i};
127 aspired_versions2.push_back(CreateAspiredVersion(
id));
128 servables.insert(
id);
130 manager_->GetAspiredVersionsCallback()(kServableName2,
131 std::move(aspired_versions2));
132 HandlePendingAspiredVersionsRequests();
134 for (
int i = 0; i < kNumTotalVersions; ++i) {
137 InvokePolicyAndExecuteAction();
139 for (
const ServableId& servable : servables) {
140 WaitUntilServableManagerStateIsOneOf(
141 servable_state_monitor_, servable,
142 {ServableState::ManagerState::kAvailable});
146 void FlushServables() {
147 test_util::AspiredVersionsManagerTestAccess(manager_.get())
151 void HandlePendingAspiredVersionsRequests() {
152 test_util::AspiredVersionsManagerTestAccess(manager_.get())
153 .HandlePendingAspiredVersionsRequests();
156 void InvokePolicyAndExecuteAction() {
157 test_util::AspiredVersionsManagerTestAccess(manager_.get())
158 .InvokePolicyAndExecuteAction();
161 std::shared_ptr<EventBus<ServableState>> servable_event_bus_;
162 ServableStateMonitor servable_state_monitor_;
163 ThreadPoolSizes thread_pool_sizes_;
164 uint32 max_num_load_retries_;
165 bool enable_reload_servables_with_error_;
166 std::unique_ptr<AspiredVersionsManager> manager_;
169 INSTANTIATE_TEST_CASE_P(
170 WithOrWithoutThreadPools, AspiredVersionsManagerTest,
173 std::make_tuple(ThreadPoolSizes{0, 0},
false),
175 std::make_tuple(ThreadPoolSizes{2, 0},
false),
177 std::make_tuple(ThreadPoolSizes{0, 2},
false),
179 std::make_tuple(ThreadPoolSizes{4, 4},
false),
181 std::make_tuple(ThreadPoolSizes{0, 0},
true),
183 std::make_tuple(ThreadPoolSizes{4, 4},
true)));
185 TEST_P(AspiredVersionsManagerTest, ServableHandleNotFoundMissingLoaderName) {
186 ServableHandle<int64_t> handle;
187 const Status status = manager_->GetServableHandle(
188 ServableRequest::Latest(strings::StrCat(kServableName,
"missing")),
190 ASSERT_FALSE(status.ok()) << status;
191 EXPECT_EQ(error::NOT_FOUND, status.code());
194 TEST_P(AspiredVersionsManagerTest, ServableHandleNotFoundMissingVersion) {
196 const int64_t missing_version = 100;
197 ServableHandle<int64_t> handle;
198 const Status status = manager_->GetServableHandle(
199 ServableRequest::Specific(kServableName, missing_version), &handle);
200 ASSERT_FALSE(status.ok()) << status;
201 EXPECT_EQ(error::NOT_FOUND, status.code());
204 TEST_P(AspiredVersionsManagerTest, ServableHandleInvalidArgument) {
207 ServableHandle<float> handle;
208 const Status status = manager_->GetServableHandle(
209 ServableRequest::Latest(kServableName), &handle);
210 ASSERT_FALSE(status.ok()) << status;
211 EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
214 TEST_P(AspiredVersionsManagerTest, ServableHandleLatest) {
215 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
216 const ServableId
id = {kServableName, kNumVersionsPerServable + 1};
217 aspired_versions.push_back(CreateAspiredVersion(
id));
218 manager_->GetAspiredVersionsCallback()(kServableName,
219 std::move(aspired_versions));
220 HandlePendingAspiredVersionsRequests();
223 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
224 InvokePolicyAndExecuteAction();
226 WaitUntilServableManagerStateIsOneOf(
227 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
229 ServableHandle<int64_t> handle;
230 const Status status = manager_->GetServableHandle(
231 ServableRequest::Latest(kServableName), &handle);
232 TF_ASSERT_OK(status);
233 EXPECT_EQ(kNumVersionsPerServable + 1, *handle);
237 TEST_P(AspiredVersionsManagerTest, ServableHandleLatestVersionIsZero) {
238 const char kServableName3[] =
"kServableName3";
240 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
241 const ServableId
id = {kServableName3, 0};
242 aspired_versions.push_back(CreateAspiredVersion(
id));
243 manager_->GetAspiredVersionsCallback()(kServableName3,
244 std::move(aspired_versions));
245 HandlePendingAspiredVersionsRequests();
247 InvokePolicyAndExecuteAction();
248 WaitUntilServableManagerStateIsOneOf(
249 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
251 ServableHandle<int64_t> handle;
252 const Status status = manager_->GetServableHandle(
253 ServableRequest::Latest(kServableName3), &handle);
254 TF_ASSERT_OK(status);
255 EXPECT_EQ(0, *handle);
256 EXPECT_EQ(
id, handle.id());
259 TEST_P(AspiredVersionsManagerTest, ReloadAspiredError) {
260 const char kServableName[] =
"kAspiredError";
261 auto callback_fn = manager_->GetAspiredVersionsCallback();
264 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
265 const ServableId
id = {kServableName, 1};
266 aspired_versions.push_back(CreateAspiredVersion(
id));
267 callback_fn(kServableName, std::move(aspired_versions));
268 HandlePendingAspiredVersionsRequests();
269 InvokePolicyAndExecuteAction();
270 WaitUntilServableManagerStateIsOneOf(
271 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
272 ServableHandle<int64_t> handle;
273 const Status status = manager_->GetServableHandle(
274 ServableRequest::Latest(kServableName), &handle);
275 TF_ASSERT_OK(status);
276 EXPECT_EQ(1, *handle);
277 EXPECT_EQ(
id, handle.id());
281 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
282 const ServableId
id = {kServableName, 2};
283 aspired_versions.push_back(CreateErroneousAspiredVersion(
id));
284 callback_fn(kServableName, std::move(aspired_versions));
285 HandlePendingAspiredVersionsRequests();
286 InvokePolicyAndExecuteAction();
287 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
288 {ServableState::ManagerState::kEnd});
289 ServableHandle<int64_t> handle;
290 Status status = manager_->GetServableHandle(
291 ServableRequest::Specific(kServableName, 2), &handle);
292 EXPECT_FALSE(status.ok()) << status;
296 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
297 const ServableId
id = {kServableName, 2};
298 aspired_versions.push_back(CreateAspiredVersion(
id));
299 callback_fn(kServableName, std::move(aspired_versions));
300 HandlePendingAspiredVersionsRequests();
301 InvokePolicyAndExecuteAction();
302 if (enable_reload_servables_with_error_) {
303 WaitUntilServableManagerStateIsOneOf(
304 servable_state_monitor_,
id,
305 {ServableState::ManagerState::kAvailable});
306 ServableHandle<int64_t> handle;
307 Status status = manager_->GetServableHandle(
308 ServableRequest::Specific(kServableName, 2), &handle);
309 TF_ASSERT_OK(status) << status;
312 Env::Default()->SleepForMicroseconds(1000 );
313 ServableHandle<int64_t> handle;
314 Status status = manager_->GetServableHandle(
315 ServableRequest::Specific(kServableName, 2), &handle);
316 EXPECT_FALSE(status.ok()) << status;
321 TEST_P(AspiredVersionsManagerTest, ServableHandleSpecificVersion) {
322 ServableHandle<int64_t> handle;
323 const ServableId
id = {kServableName2, 0};
324 const Status status =
325 manager_->GetServableHandle(ServableRequest::FromId(
id), &handle);
326 TF_ASSERT_OK(status);
327 EXPECT_EQ(0, *handle);
328 EXPECT_EQ(
id, handle.id());
331 TEST_P(AspiredVersionsManagerTest, ListAvailableServableIds) {
332 const std::vector<ServableId> expected_before = {{kServableName, 0},
335 {kServableName2, 1}};
336 EXPECT_THAT(manager_->ListAvailableServableIds(),
337 UnorderedElementsAreArray(expected_before));
342 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
343 const ServableId
id = {kServableName, 7};
344 std::unique_ptr<Loader> loader(
345 new FakeLoader(7, errors::Internal(
"An error.")));
346 aspired_versions.push_back({id, std::move(loader)});
347 manager_->GetAspiredVersionsCallback()(kServableName,
348 std::move(aspired_versions));
349 HandlePendingAspiredVersionsRequests();
350 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
351 InvokePolicyAndExecuteAction();
353 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
354 {ServableState::ManagerState::kEnd});
356 manager_->GetAspiredVersionsCallback()(kServableName, {});
357 HandlePendingAspiredVersionsRequests();
358 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
359 InvokePolicyAndExecuteAction();
361 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
363 {ServableState::ManagerState::kEnd});
364 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
366 {ServableState::ManagerState::kEnd});
368 const std::vector<ServableId> expected_after = {{kServableName2, 0},
369 {kServableName2, 1}};
370 EXPECT_THAT(manager_->ListAvailableServableIds(),
371 UnorderedElementsAreArray(expected_after));
374 TEST_P(AspiredVersionsManagerTest, GetAvailableServableHandles) {
377 const std::map<ServableId, ServableHandle<int64_t>> handles_before =
378 manager_->GetAvailableServableHandles<int64_t>();
379 ASSERT_EQ(kNumVersionsPerServable * 2, handles_before.size());
381 const std::vector<ServableId> expected_ids_before = {{kServableName, 0},
384 {kServableName2, 1}};
385 for (
const ServableId& expected_id : expected_ids_before) {
386 const auto found_it = handles_before.find(expected_id);
387 ASSERT_TRUE(found_it != handles_before.end());
388 EXPECT_EQ(expected_id.version, *found_it->second);
395 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
396 const ServableId
id = {kServableName, 7};
397 std::unique_ptr<Loader> loader(
398 new FakeLoader(7, errors::Internal(
"An error.")));
399 aspired_versions.push_back({id, std::move(loader)});
400 manager_->GetAspiredVersionsCallback()(kServableName,
401 std::move(aspired_versions));
402 HandlePendingAspiredVersionsRequests();
403 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
404 InvokePolicyAndExecuteAction();
406 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
407 {ServableState::ManagerState::kEnd});
409 manager_->GetAspiredVersionsCallback()(kServableName, {});
410 HandlePendingAspiredVersionsRequests();
411 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
412 InvokePolicyAndExecuteAction();
414 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
416 {ServableState::ManagerState::kEnd});
417 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
419 {ServableState::ManagerState::kEnd});
421 const std::map<ServableId, ServableHandle<int64_t>> handles_after =
422 manager_->GetAvailableServableHandles<int64_t>();
423 ASSERT_EQ(kNumVersionsPerServable, handles_after.size());
425 const std::vector<ServableId> expected_ids_after = {{kServableName2, 0},
426 {kServableName2, 1}};
427 for (
const ServableId& expected_id : expected_ids_after) {
428 const auto found_it = handles_after.find(expected_id);
429 ASSERT_TRUE(found_it != handles_after.end());
430 EXPECT_EQ(expected_id.version, *found_it->second);
435 TEST_P(AspiredVersionsManagerTest, GetAvailableServableHandlesWrongType) {
436 const std::map<ServableId, ServableHandle<int>> wrong_type_handles =
437 manager_->GetAvailableServableHandles<
int>();
438 EXPECT_EQ(0, wrong_type_handles.size());
441 TEST_P(AspiredVersionsManagerTest, AspiredRemovedFull) {
445 ServableHandle<int64_t> handle;
446 const Status status = manager_->GetServableHandle(
447 ServableRequest::Latest(kServableName), &handle);
448 TF_ASSERT_OK(status);
449 EXPECT_EQ(1, *handle);
452 manager_->GetAspiredVersionsCallback()(kServableName, {});
453 HandlePendingAspiredVersionsRequests();
455 const int num_fake_loaders_before = FakeLoader::num_fake_loaders();
456 for (
int i = 0; i < kNumVersionsPerServable; ++i) {
457 InvokePolicyAndExecuteAction();
459 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
461 {ServableState::ManagerState::kEnd});
462 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
464 {ServableState::ManagerState::kEnd});
466 const int num_fake_loaders_after = FakeLoader::num_fake_loaders();
467 EXPECT_EQ(kNumVersionsPerServable,
468 num_fake_loaders_before - num_fake_loaders_after);
470 ServableHandle<int64_t> missing_handle;
471 const Status missing_status = manager_->GetServableHandle(
472 ServableRequest::Latest(kServableName), &missing_handle);
473 ASSERT_FALSE(missing_status.ok());
474 EXPECT_EQ(error::NOT_FOUND, missing_status.code());
477 TEST_P(AspiredVersionsManagerTest, AspiredRemovedPartial) {
478 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
479 aspired_versions.push_back(CreateAspiredVersion({kServableName, 0}));
480 manager_->GetAspiredVersionsCallback()(kServableName,
481 std::move(aspired_versions));
482 HandlePendingAspiredVersionsRequests();
484 InvokePolicyAndExecuteAction();
485 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
487 {ServableState::ManagerState::kEnd});
490 ServableHandle<int64_t> v0_handle;
491 const Status v0_status = manager_->GetServableHandle(
492 ServableRequest::Specific(kServableName, 0), &v0_handle);
493 TF_ASSERT_OK(v0_status);
494 EXPECT_EQ(0, *v0_handle);
497 ServableHandle<int64_t> v1_handle;
498 const Status v1_status = manager_->GetServableHandle(
499 ServableRequest::Specific(kServableName, 1), &v1_handle);
500 ASSERT_FALSE(v1_status.ok());
501 EXPECT_EQ(error::NOT_FOUND, v1_status.code());
504 TEST_P(AspiredVersionsManagerTest, RevertToSmallerVersionNumber) {
506 std::set<int64_t> initial_versions;
507 for (
const ServableId&
id : manager_->ListAvailableServableIds()) {
508 if (
id.name == kServableName) {
509 initial_versions.insert(
id.version);
512 ASSERT_THAT(initial_versions, UnorderedElementsAre(0, 1));
515 std::vector<ServableData<std::unique_ptr<Loader>>> initial_aspired_versions;
516 initial_aspired_versions.push_back(CreateAspiredVersion({kServableName, 1}));
517 manager_->GetAspiredVersionsCallback()(kServableName,
518 std::move(initial_aspired_versions));
519 HandlePendingAspiredVersionsRequests();
520 InvokePolicyAndExecuteAction();
521 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
523 {ServableState::ManagerState::kEnd});
527 std::vector<ServableData<std::unique_ptr<Loader>>> new_aspired_versions;
528 new_aspired_versions.push_back(CreateAspiredVersion({kServableName, 0}));
529 manager_->GetAspiredVersionsCallback()(kServableName,
530 std::move(new_aspired_versions));
531 HandlePendingAspiredVersionsRequests();
532 Notification done_transitioning;
533 std::unique_ptr<Thread> transition_servables(
534 Env::Default()->StartThread({},
"TransitionServables", [&]() {
535 while (!done_transitioning.HasBeenNotified()) {
536 InvokePolicyAndExecuteAction();
537 Env::Default()->SleepForMicroseconds(1000 );
540 WaitUntilServableManagerStateIsOneOf(
541 servable_state_monitor_, {kServableName, 0},
542 {ServableState::ManagerState::kAvailable});
543 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
545 {ServableState::ManagerState::kEnd});
546 done_transitioning.Notify();
549 ServableHandle<int64_t> v0_handle;
550 const Status v0_status = manager_->GetServableHandle(
551 ServableRequest::Specific(kServableName, 0), &v0_handle);
552 TF_ASSERT_OK(v0_status);
553 EXPECT_EQ(0, *v0_handle);
556 ServableHandle<int64_t> v1_handle;
557 const Status v1_status = manager_->GetServableHandle(
558 ServableRequest::Specific(kServableName, 1), &v1_handle);
559 ASSERT_FALSE(v1_status.ok());
560 EXPECT_EQ(error::NOT_FOUND, v1_status.code());
563 TEST_P(AspiredVersionsManagerTest, AspiredAndManageStateLoad) {
564 const ServableId
id = {kServableName, 2};
565 ServableHandle<int64_t> not_found_handle;
566 const Status not_found_status = manager_->GetServableHandle(
567 ServableRequest::FromId(
id), ¬_found_handle);
568 ASSERT_FALSE(not_found_status.ok()) << not_found_status;
569 EXPECT_EQ(error::NOT_FOUND, not_found_status.code());
571 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
572 aspired_versions.push_back(CreateAspiredVersion(
id));
573 manager_->GetAspiredVersionsCallback()(kServableName,
574 std::move(aspired_versions));
575 HandlePendingAspiredVersionsRequests();
577 ServableHandle<int64_t> not_ready_handle;
578 const Status not_ready_status = manager_->GetServableHandle(
579 ServableRequest::FromId(
id), ¬_ready_handle);
580 ASSERT_FALSE(not_ready_status.ok()) << not_ready_status;
581 EXPECT_EQ(error::NOT_FOUND, not_ready_status.code());
585 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
586 InvokePolicyAndExecuteAction();
588 WaitUntilServableManagerStateIsOneOf(
589 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
591 ServableHandle<int64_t> handle;
592 const Status status =
593 manager_->GetServableHandle(ServableRequest::FromId(
id), &handle);
594 TF_ASSERT_OK(status);
595 EXPECT_EQ(2, *handle);
598 TEST_P(AspiredVersionsManagerTest, AspiredAndManageStateUnload) {
600 ServableHandle<int64_t> handle;
601 const Status status = manager_->GetServableHandle(
602 ServableRequest::Specific(kServableName, 0), &handle);
603 TF_ASSERT_OK(status);
604 EXPECT_EQ(0, *handle);
607 manager_->GetAspiredVersionsCallback()(kServableName, {});
608 HandlePendingAspiredVersionsRequests();
610 for (
int i = 0; i < kNumVersionsPerServable; ++i) {
611 InvokePolicyAndExecuteAction();
613 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
615 {ServableState::ManagerState::kEnd});
616 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
618 {ServableState::ManagerState::kEnd});
620 ServableHandle<int64_t> not_found_handle;
621 const Status not_found_status = manager_->GetServableHandle(
622 ServableRequest::Specific(kServableName, 0), ¬_found_handle);
623 ASSERT_FALSE(not_found_status.ok()) << not_found_status;
624 EXPECT_EQ(error::NOT_FOUND, not_found_status.code());
629 TEST_P(AspiredVersionsManagerTest, ManagerPrefersUnloadOverLoad) {
630 ServableHandle<int64_t> not_found_2_handle;
631 Status not_found_2_status = manager_->GetServableHandle(
632 ServableRequest::Specific(kServableName2, 2), ¬_found_2_handle);
633 ASSERT_FALSE(not_found_2_status.ok()) << not_found_2_status;
634 EXPECT_EQ(error::NOT_FOUND, not_found_2_status.code());
637 ServableHandle<int64_t> found_0_handle;
638 const Status found_0_status = manager_->GetServableHandle(
639 ServableRequest::Specific(kServableName, 0), &found_0_handle);
640 TF_ASSERT_OK(found_0_status);
641 EXPECT_EQ(0, *found_0_handle);
650 } servable_aspired_list[2] = {{kServableName, 1, 1}, {kServableName2, 0, 2}};
651 for (
const auto& servable_aspired : servable_aspired_list) {
652 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
653 for (
int i = servable_aspired.start; i <= servable_aspired.end; ++i) {
654 const ServableId
id = {string(servable_aspired.name), i};
655 aspired_versions.push_back(CreateAspiredVersion(
id));
657 manager_->GetAspiredVersionsCallback()(servable_aspired.name,
658 std::move(aspired_versions));
659 HandlePendingAspiredVersionsRequests();
664 InvokePolicyAndExecuteAction();
665 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
667 {ServableState::ManagerState::kEnd});
669 ServableHandle<int64_t> not_found_0_handle;
670 const Status not_found_0_status = manager_->GetServableHandle(
671 ServableRequest::Specific(kServableName, 0), ¬_found_0_handle);
672 ASSERT_FALSE(not_found_0_status.ok()) << not_found_0_status;
673 EXPECT_EQ(error::NOT_FOUND, not_found_2_status.code());
675 not_found_2_status = manager_->GetServableHandle(
676 ServableRequest::Specific(kServableName2, 2), ¬_found_2_handle);
677 ASSERT_FALSE(not_found_2_status.ok()) << not_found_2_status;
678 EXPECT_EQ(error::NOT_FOUND, not_found_2_status.code());
681 InvokePolicyAndExecuteAction();
682 WaitUntilServableManagerStateIsOneOf(
683 servable_state_monitor_, {kServableName2, 2},
684 {ServableState::ManagerState::kAvailable});
686 ServableHandle<int64_t> found_2_handle;
687 const Status found_2_status = manager_->GetServableHandle(
688 ServableRequest::Specific(kServableName2, 2), &found_2_handle);
689 TF_ASSERT_OK(found_2_status);
690 EXPECT_EQ(2, *found_2_handle);
693 TEST_P(AspiredVersionsManagerTest, CustomSortActions) {
694 test_util::AspiredVersionsManagerTestAccess(manager_.get())
695 .SetCustomSortActions(
696 [](
const AspiredVersionPolicy::ServableAction& lhs,
697 const AspiredVersionPolicy::ServableAction& rhs) ->
bool {
700 bool lhs_is_servable_2 = lhs.id.name == kServableName2;
701 bool rhs_is_servable_2 = rhs.id.name == kServableName2;
702 if (lhs_is_servable_2 != rhs_is_servable_2) {
703 return lhs_is_servable_2;
709 ServableHandle<int64_t> not_found_2_handle;
710 Status not_found_2_status = manager_->GetServableHandle(
711 ServableRequest::Specific(kServableName2, 2), ¬_found_2_handle);
712 ASSERT_FALSE(not_found_2_status.ok()) << not_found_2_status;
713 EXPECT_EQ(error::NOT_FOUND, not_found_2_status.code());
717 ServableHandle<int64_t> found_0_handle;
718 Status found_0_status = manager_->GetServableHandle(
719 ServableRequest::Specific(kServableName, 0), &found_0_handle);
720 TF_ASSERT_OK(found_0_status);
721 EXPECT_EQ(0, *found_0_handle);
730 } servable_aspired_list[2] = {{kServableName, 1, 1}, {kServableName2, 0, 2}};
731 for (
const auto& servable_aspired : servable_aspired_list) {
732 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
733 for (
int i = servable_aspired.start; i <= servable_aspired.end; ++i) {
734 const ServableId
id = {string(servable_aspired.name), i};
735 aspired_versions.push_back(CreateAspiredVersion(
id));
737 manager_->GetAspiredVersionsCallback()(servable_aspired.name,
738 std::move(aspired_versions));
739 HandlePendingAspiredVersionsRequests();
745 InvokePolicyAndExecuteAction();
746 WaitUntilServableManagerStateIsOneOf(
747 servable_state_monitor_, {kServableName2, 2},
748 {ServableState::ManagerState::kAvailable});
751 ServableHandle<int64_t> found_0_handle;
752 Status found_0_status = manager_->GetServableHandle(
753 ServableRequest::Specific(kServableName, 0), &found_0_handle);
754 TF_ASSERT_OK(found_0_status);
755 EXPECT_EQ(0, *found_0_handle);
759 ServableHandle<int64_t> found_2_handle;
760 const Status found_2_status = manager_->GetServableHandle(
761 ServableRequest::Specific(kServableName2, 2), &found_2_handle);
762 TF_ASSERT_OK(found_2_status);
763 EXPECT_EQ(2, *found_2_handle);
767 InvokePolicyAndExecuteAction();
768 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
770 {ServableState::ManagerState::kEnd});
773 ServableHandle<int64_t> not_found_0_handle;
774 const Status not_found_0_status = manager_->GetServableHandle(
775 ServableRequest::Specific(kServableName, 0), ¬_found_0_handle);
776 ASSERT_FALSE(not_found_0_status.ok()) << not_found_0_status;
777 EXPECT_EQ(error::NOT_FOUND, not_found_0_status.code());
783 TEST_P(AspiredVersionsManagerTest, ErroneousAspiredVersion) {
784 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
785 aspired_versions.push_back(CreateErroneousAspiredVersion({kServableName, 3}));
786 manager_->GetAspiredVersionsCallback()(kServableName,
787 std::move(aspired_versions));
788 HandlePendingAspiredVersionsRequests();
790 ServableHandle<int64_t> handle;
791 Status status = manager_->GetServableHandle(
792 ServableRequest::Specific(kServableName, 3), &handle);
793 EXPECT_FALSE(status.ok()) << status;
795 InvokePolicyAndExecuteAction();
797 status = manager_->GetServableHandle(
798 ServableRequest::Specific(kServableName, 3), &handle);
799 EXPECT_FALSE(status.ok()) << status;
804 TEST_P(AspiredVersionsManagerTest, DestructOnNonServingThread) {
805 std::unique_ptr<ServableHandle<int64_t>> latest_handle(
806 new ServableHandle<int64_t>());
807 const Status status = manager_->GetServableHandle(
808 ServableRequest::Latest(kServableName), latest_handle.get());
809 TF_ASSERT_OK(status);
810 EXPECT_EQ(1, **latest_handle);
812 manager_->GetAspiredVersionsCallback()(kServableName, {});
813 HandlePendingAspiredVersionsRequests();
815 Notification done_unload_servable;
816 std::unique_ptr<Thread> unload_servable(
817 Env::Default()->StartThread({},
"UnloadServable", [&]() {
819 for (
int i = 0; i < kNumVersionsPerServable; ++i) {
820 InvokePolicyAndExecuteAction();
822 WaitUntilServableManagerStateIsOneOf(
823 servable_state_monitor_, {kServableName, 0},
824 {ServableState::ManagerState::kEnd});
828 if (thread_pool_sizes_.num_unload_threads == 0) {
829 EXPECT_TRUE(FakeLoader::was_deleted_in_this_thread());
831 done_unload_servable.Notify();
835 latest_handle.reset();
836 done_unload_servable.WaitForNotification();
838 EXPECT_FALSE(FakeLoader::was_deleted_in_this_thread());
841 MATCHER_P(EqualsServableState, servable_state, servable_state.DebugString()) {
842 if (arg == servable_state) {
845 *result_listener << arg.DebugString();
849 TEST_P(AspiredVersionsManagerTest, EventBusErroneousVersion) {
850 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
851 const ServableId
id = {kServableName, 3};
852 aspired_versions.push_back(
853 ServableData<std::unique_ptr<Loader>>(
id, errors::Unknown(
"error")));
854 manager_->GetAspiredVersionsCallback()(kServableName,
855 std::move(aspired_versions));
856 HandlePendingAspiredVersionsRequests();
858 const ServableState expected_published_state = {
859 id, ServableState::ManagerState::kEnd, errors::Unknown(
"error")};
860 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
861 EqualsServableState(expected_published_state));
864 TEST_P(AspiredVersionsManagerTest, EventBusErrorOnLoad) {
865 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
866 const ServableId
id = {kServableName, 7};
867 std::unique_ptr<Loader> loader(
868 new FakeLoader(7, errors::Internal(
"Error on load.")));
869 aspired_versions.push_back({id, std::move(loader)});
870 manager_->GetAspiredVersionsCallback()(kServableName,
871 std::move(aspired_versions));
872 HandlePendingAspiredVersionsRequests();
874 const ServableState start_state = {id, ServableState::ManagerState::kStart,
876 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
877 EqualsServableState(start_state));
881 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
882 InvokePolicyAndExecuteAction();
884 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
885 {ServableState::ManagerState::kEnd});
887 const ServableState error_state = {id, ServableState::ManagerState::kEnd,
888 errors::Internal(
"Error on load.")};
889 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
890 EqualsServableState(error_state));
// Verifies that the servable state monitor (presumably subscribed to the
// fixture's servable event bus) observes every lifecycle state of servable
// {kServableName, 7} in order: kStart, kLoading, kAvailable, kUnloading and
// kEnd. The mock loader blocks inside LoadWithMetadata() and Unload() on
// notifications so the transient kLoading / kUnloading states can be checked
// while they are still current.
// NOTE(review): this copy is missing source lines (e.g. 903, which should
// complete `start_state`, the lambda/thread closing lines, and the final
// `}`); restore them from upstream before building.
893 TEST_P(AspiredVersionsManagerTest, EventBusServableLifecycle) {
894 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
895 const ServableId
id = {kServableName, 7};
// The manager takes ownership of the mock; `loader` stays as a raw observer
// so expectations can be set on it later in the test.
896 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>();
897 aspired_versions.push_back({id, std::unique_ptr<Loader>(loader)});
898 manager_->GetAspiredVersionsCallback()(kServableName,
899 std::move(aspired_versions));
900 HandlePendingAspiredVersionsRequests();
// Right after the aspire request is handled, the servable is in kStart.
902 const ServableState start_state = {id, ServableState::ManagerState::kStart,
904 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
905 EqualsServableState(start_state));
// Make LoadWithMetadata() signal that it was entered, then block until the
// test releases it, so the kLoading state can be observed.
907 Notification load_called;
908 Notification load_continue;
909 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{
id}))
910 .WillOnce(InvokeWithoutArgs([&]() {
911 load_called.Notify();
912 load_continue.WaitForNotification();
// Drive the policy on a separate thread, because the load it triggers blocks
// until load_continue is notified below.
916 std::unique_ptr<Thread> load_unload_thread(
917 Env::Default()->StartThread(ThreadOptions(),
"LoadUnloadThread", [&]() {
921 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
922 InvokePolicyAndExecuteAction();
// The loader is now parked inside LoadWithMetadata(): expect kLoading.
926 load_called.WaitForNotification();
928 const ServableState loading_state = {
929 id, ServableState::ManagerState::kLoading, OkStatus()};
930 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
931 EqualsServableState(loading_state));
// Let the load finish and wait for the servable to become available.
933 load_continue.Notify();
934 WaitUntilServableManagerStateIsOneOf(
935 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
937 const ServableState available_state = {
938 id, ServableState::ManagerState::kAvailable, OkStatus()};
939 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
940 EqualsServableState(available_state));
// Un-aspire every version of kServableName so the manager unloads version 7.
942 manager_->GetAspiredVersionsCallback()(kServableName, {});
943 HandlePendingAspiredVersionsRequests();
// Make Unload() signal that it was entered, then block until released, so
// the kUnloading state can be observed.
945 Notification unload_called;
946 Notification unload_continue;
947 EXPECT_CALL(*loader, Unload()).WillOnce(Invoke([&]() {
948 unload_called.Notify();
949 unload_continue.WaitForNotification();
// Invoke the policy (twice) on a separate thread, since Unload() blocks
// until unload_continue is notified below.
952 std::unique_ptr<Thread> unload_thread(
953 Env::Default()->StartThread(ThreadOptions(),
"UnloadThread", [&]() {
956 InvokePolicyAndExecuteAction();
957 InvokePolicyAndExecuteAction();
// The loader is now parked inside Unload(): expect kUnloading.
960 unload_called.WaitForNotification();
962 const ServableState unloading_state = {
963 id, ServableState::ManagerState::kUnloading, OkStatus()};
964 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
965 EqualsServableState(unloading_state));
// Let the unload finish; the servable should end in kEnd with an OK status.
967 unload_continue.Notify();
968 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
969 {ServableState::ManagerState::kEnd});
971 const ServableState end_state = {
972 {kServableName, 7}, ServableState::ManagerState::kEnd, OkStatus()};
973 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
974 EqualsServableState(end_state));
978 TEST_P(AspiredVersionsManagerTest, NoEventBus) {
979 AspiredVersionsManager::Options options;
981 options.manage_state_interval_micros = -1;
982 options.env = Env::Default();
983 options.aspired_version_policy.reset(
new AvailabilityPreservingPolicy());
984 std::unique_ptr<AspiredVersionsManager> aspired_versions_manager;
985 TF_ASSERT_OK(AspiredVersionsManager::Create(std::move(options),
986 &aspired_versions_manager));
988 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
989 const ServableId
id = {kServableName, 7};
990 std::unique_ptr<Loader> loader(
new FakeLoader(7));
991 aspired_versions.push_back({id, std::move(loader)});
992 aspired_versions_manager->GetAspiredVersionsCallback()(
993 kServableName, std::move(aspired_versions));
994 HandlePendingAspiredVersionsRequests();
// Verifies load retries: the first LoadWithMetadata() call fails, the retry
// succeeds, and the servable ends up kAvailable with an OK status.
// Requires the fixture to be configured with at least one load retry
// (enforced by the CHECK_GE below).
997 TEST_P(AspiredVersionsManagerTest, RetryOnLoadErrorFinallySucceeds) {
998 CHECK_GE(max_num_load_retries_, 1);
999 const ServableId
id = {kServableName, 7};
// Owned by the manager once aspired; kept raw to set expectations.
1000 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
// First attempt fails with an Internal error, the retry succeeds.
1002 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{
id}))
1003 .WillOnce(Return(errors::Internal(
"Error on load.")))
1004 .WillOnce(Return(OkStatus()));
1006 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1007 aspired_versions.push_back({id, std::unique_ptr<Loader>(loader)});
1008 manager_->GetAspiredVersionsCallback()(kServableName,
1009 std::move(aspired_versions));
1010 HandlePendingAspiredVersionsRequests();
// Run kNumVersionsPerServable + 1 policy steps; presumably the fixture's
// pre-existing versions of kServableName are handled before version 7's
// load is reached — TODO confirm against upstream comments.
1014 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
1015 InvokePolicyAndExecuteAction();
1017 WaitUntilServableManagerStateIsOneOf(
1018 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
1020 const ServableState available_state = {
1021 id, ServableState::ManagerState::kAvailable, OkStatus()};
1022 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
1023 EqualsServableState(available_state));
// Verifies that when loading keeps failing even with retries enabled, the
// servable ends in kEnd carrying the load error status.
// Requires at least one configured load retry (CHECK_GE below).
1026 TEST_P(AspiredVersionsManagerTest, RetryOnLoadErrorFinallyFails) {
1027 CHECK_GE(max_num_load_retries_, 1);
1029 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1030 const ServableId
id = {kServableName, 7};
// FakeLoader is constructed with an Internal error, which it presumably
// returns from every load attempt — TODO confirm in fake_loader.h.
1032 std::unique_ptr<Loader> loader(
1033 new FakeLoader(7, errors::Internal(
"Error on load.")));
1034 aspired_versions.push_back({id, std::move(loader)});
1035 manager_->GetAspiredVersionsCallback()(kServableName,
1036 std::move(aspired_versions));
1037 HandlePendingAspiredVersionsRequests();
// Same number of policy steps as the success case; see the retry-succeeds
// test for the (unconfirmed) rationale.
1041 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
1042 InvokePolicyAndExecuteAction();
1044 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
1045 {ServableState::ManagerState::kEnd});
// The final state must carry the same error the loader produced.
1047 const ServableState error_state = {id, ServableState::ManagerState::kEnd,
1048 errors::Internal(
"Error on load.")};
1049 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
1050 EqualsServableState(error_state));
// Verifies that a failed load of a newly aspired version does not leave
// kServableName without an available version. After version 7 fails,
// {kServableName, 1} is still listed as available. A later version 8 then
// loads successfully, after which another version reaches kEnd.
// NOTE(review): this copy is missing source lines.
//   - Line 1060 is dropped from `expected_before`.
//   - The ServableId arguments at 1082 and 1117 are dropped.
//   - The braces that scope the two redeclarations of
//     `aspired_versions` / `id` / `loader` are dropped; as shown, those
//     names are redeclared in the same scope.
// Restore from upstream before building.
1058 TEST_P(AspiredVersionsManagerTest, AspireErrorDontUnload) {
1059 const std::vector<ServableId> expected_before = {{kServableName, 0},
1061 {kServableName2, 0},
1062 {kServableName2, 1}};
1063 EXPECT_THAT(manager_->ListAvailableServableIds(),
1064 UnorderedElementsAreArray(expected_before));
// Aspire only version 7, whose loader is built to fail with "An error.".
1070 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1071 const ServableId
id = {kServableName, 7};
1072 std::unique_ptr<Loader> loader(
1073 new FakeLoader(7, errors::Internal(
"An error.")));
1074 aspired_versions.push_back({id, std::move(loader)});
1075 manager_->GetAspiredVersionsCallback()(kServableName,
1076 std::move(aspired_versions));
1077 HandlePendingAspiredVersionsRequests();
// First policy step: an existing kServableName version reaches kEnd. Its id
// is on a missing line; judging by expected_after_first_load it is version 0.
1080 InvokePolicyAndExecuteAction();
1081 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
1083 {ServableState::ManagerState::kEnd});
// Second policy step: the failing load of version 7 ends in kEnd.
1086 InvokePolicyAndExecuteAction();
1087 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
1088 {ServableState::ManagerState::kEnd});
// Version 1 must still be available despite the failed aspire of version 7.
1093 const std::vector<ServableId> expected_after_first_load = {
1094 {kServableName, 1}, {kServableName2, 0}, {kServableName2, 1}};
1095 EXPECT_THAT(manager_->ListAvailableServableIds(),
1096 UnorderedElementsAreArray(expected_after_first_load));
// Now aspire a version 8 that loads successfully.
1101 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1102 const ServableId
id = {kServableName, 8};
1103 std::unique_ptr<Loader> loader(
new FakeLoader(8));
1104 aspired_versions.push_back({id, std::move(loader)});
1105 manager_->GetAspiredVersionsCallback()(kServableName,
1106 std::move(aspired_versions));
1107 HandlePendingAspiredVersionsRequests();
// One step loads version 8 ...
1110 InvokePolicyAndExecuteAction();
1111 WaitUntilServableManagerStateIsOneOf(
1112 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
// ... and the next brings another version (id on a missing line) to kEnd.
1115 InvokePolicyAndExecuteAction();
1116 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
1118 {ServableState::ManagerState::kEnd});
// Verifies un-aspiring and immediately re-aspiring {kServableName, 7} while
// the first instance is still held by an outstanding ServableHandle.
// Neither the first instance's Unload() nor the second loader's Load happens
// until that handle is released. After release, the old instance is
// unloaded and the new one is loaded.
// NOTE(review): this copy is missing source lines. `servable` (used at
// 1155) is never declared here, the kEnd wait below lacks its ServableId
// argument (line 1149), and lambda/thread closing lines are absent. Restore
// from upstream.
1122 TEST_P(AspiredVersionsManagerTest, UnaspireThenImmediatelyReaspire) {
1129 const ServableId
id = {kServableName, 7};
// First instance of version 7: loads successfully exactly once.
1131 std::vector<ServableData<std::unique_ptr<Loader>>> first_aspired_versions;
1132 test_util::MockLoader* first_loader =
new NiceMock<test_util::MockLoader>();
1133 first_aspired_versions.push_back({id, std::unique_ptr<Loader>(first_loader)});
1134 EXPECT_CALL(*first_loader, LoadWithMetadata(Loader::Metadata{
id}))
1135 .WillOnce(Return(OkStatus()));
1136 manager_->GetAspiredVersionsCallback()(kServableName,
1137 std::move(first_aspired_versions));
1138 HandlePendingAspiredVersionsRequests();
// Two policy steps: version 7 becomes available and another kServableName
// version (id on a missing line) reaches kEnd.
1144 InvokePolicyAndExecuteAction();
1145 InvokePolicyAndExecuteAction();
1146 WaitUntilServableManagerStateIsOneOf(
1147 servable_state_monitor_,
id, {ServableState::ManagerState::kAvailable});
1148 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
1150 {ServableState::ManagerState::kEnd});
// Take a handle to the first instance; holding it should keep the instance
// from being unloaded.
1154 EXPECT_CALL(*first_loader, servable()).WillOnce(InvokeWithoutArgs([&]() {
1155 return AnyPtr{&servable};
1157 auto first_loader_handle =
1158 std::unique_ptr<ServableHandle<int>>(
new ServableHandle<int>);
1159 TF_ASSERT_OK(manager_->GetServableHandle(ServableRequest::FromId(
id),
1160 first_loader_handle.get()));
// Record when the first instance is finally unloaded.
1166 Notification first_unload_called;
1167 EXPECT_CALL(*first_loader, Unload()).WillOnce(InvokeWithoutArgs([&]() {
1168 first_unload_called.Notify();
// Un-aspire every version of kServableName.
1171 std::vector<ServableData<std::unique_ptr<Loader>>> empty_aspired_versions;
1172 manager_->GetAspiredVersionsCallback()(kServableName,
1173 std::move(empty_aspired_versions));
1174 HandlePendingAspiredVersionsRequests();
// Run the unload attempt on a separate thread, since it is expected to wait
// on the outstanding handle.
1178 std::unique_ptr<Thread> unload_thread(
1179 Env::Default()->StartThread(ThreadOptions(),
"UnloadThread", [&]() {
1181 InvokePolicyAndExecuteAction();
1182 InvokePolicyAndExecuteAction();
// Immediately re-aspire the same id with a second loader.
1186 std::vector<ServableData<std::unique_ptr<Loader>>> second_aspired_versions;
1187 test_util::MockLoader* second_loader =
new NiceMock<test_util::MockLoader>();
1188 second_aspired_versions.push_back(
1189 {id, std::unique_ptr<Loader>(second_loader)});
1190 Notification second_load_called;
1191 EXPECT_CALL(*second_loader, LoadWithMetadata(Loader::Metadata{
id}))
1192 .WillOnce(InvokeWithoutArgs([&]() {
1193 second_load_called.Notify();
1196 manager_->GetAspiredVersionsCallback()(kServableName,
1197 std::move(second_aspired_versions));
// Keep handling requests and running the policy (polling every 1 ms) until
// the second load is observed.
1201 std::unique_ptr<Thread> reaspire_thread(
1202 Env::Default()->StartThread(ThreadOptions(),
"ReaspireThread", [&]() {
1203 while (!second_load_called.HasBeenNotified()) {
1205 HandlePendingAspiredVersionsRequests();
1206 InvokePolicyAndExecuteAction();
1207 Env::Default()->SleepForMicroseconds(1000 );
// After 50 ms, neither the old unload nor the new load may have happened,
// because the handle to the first instance is still held.
1210 Env::Default()->SleepForMicroseconds(50 * 1000 );
1211 EXPECT_FALSE(first_unload_called.HasBeenNotified());
1212 EXPECT_FALSE(second_load_called.HasBeenNotified());
// Release the handle; both the unload and the second load should follow.
1216 first_loader_handle =
nullptr;
1217 first_unload_called.WaitForNotification();
1218 second_load_called.WaitForNotification();
// Verifies that a version whose loads always fail can be un-aspired and then
// re-aspired with a new loader, and that the new loader is eventually loaded.
// NOTE(review): source lines 1281-1284, between the 50 ms check and the final
// wait, are missing from this copy. Whatever happens there before the second
// load is awaited is not visible here; confirm against upstream.
1221 TEST_P(AspiredVersionsManagerTest,
1222 UnaspireFailedServableThenImmediatelyReaspire) {
1226 const ServableId
id = {kServableName, 7};
// First instance: every load attempt fails with an kUnknown status.
1228 std::vector<ServableData<std::unique_ptr<Loader>>> first_aspired_versions;
1229 test_util::MockLoader* first_loader =
new NiceMock<test_util::MockLoader>();
1230 first_aspired_versions.push_back({id, std::unique_ptr<Loader>(first_loader)});
1231 EXPECT_CALL(*first_loader, LoadWithMetadata(Loader::Metadata{
id}))
1232 .WillRepeatedly(Return(
1233 Status(
static_cast<tsl::errors::Code
>(absl::StatusCode::kUnknown),
1234 "first load failing")));
1235 manager_->GetAspiredVersionsCallback()(kServableName,
1236 std::move(first_aspired_versions));
1237 HandlePendingAspiredVersionsRequests();
// Run policy steps until the failing version has reached kEnd.
1240 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
1241 InvokePolicyAndExecuteAction();
1243 WaitUntilServableManagerStateIsOneOf(servable_state_monitor_,
id,
1244 {ServableState::ManagerState::kEnd});
// Un-aspire every version of kServableName.
1250 std::vector<ServableData<std::unique_ptr<Loader>>> empty_aspired_versions;
1251 manager_->GetAspiredVersionsCallback()(kServableName,
1252 std::move(empty_aspired_versions));
1253 HandlePendingAspiredVersionsRequests();
// Re-aspire the same id with a second loader that records its Load call.
1256 std::vector<ServableData<std::unique_ptr<Loader>>> second_aspired_versions;
1257 test_util::MockLoader* second_loader =
new NiceMock<test_util::MockLoader>();
1258 second_aspired_versions.push_back(
1259 {id, std::unique_ptr<Loader>(second_loader)});
1260 Notification second_load_called;
1261 EXPECT_CALL(*second_loader, LoadWithMetadata(Loader::Metadata{
id}))
1262 .WillOnce(InvokeWithoutArgs([&]() {
1263 second_load_called.Notify();
1266 manager_->GetAspiredVersionsCallback()(kServableName,
1267 std::move(second_aspired_versions));
// Poll (every 1 ms) request handling and the policy until the second load.
1271 std::unique_ptr<Thread> reaspire_thread(
1272 Env::Default()->StartThread(ThreadOptions(),
"ReaspireThread", [&]() {
1273 while (!second_load_called.HasBeenNotified()) {
1274 HandlePendingAspiredVersionsRequests();
1275 InvokePolicyAndExecuteAction();
1276 Env::Default()->SleepForMicroseconds(1000 );
// After 50 ms the second load is expected not to have happened yet.
1279 Env::Default()->SleepForMicroseconds(50 * 1000 );
1280 EXPECT_FALSE(second_load_called.HasBeenNotified());
1285 second_load_called.WaitForNotification();
// Verifies a three-step sequence on {kServableName, 7}, all before the
// policy has ever acted on it: aspire, un-aspire, then re-aspire with a
// new loader. Only the second loader is ever loaded; the first loader's
// LoadWithMetadata() must never be called (Times(0)).
1288 TEST_P(AspiredVersionsManagerTest, UnaspireNewServableThenImmediatelyReaspire) {
1293 const ServableId
id = {kServableName, 7};
// First loader: aspired but never loaded.
1295 std::vector<ServableData<std::unique_ptr<Loader>>> first_aspired_versions;
1296 test_util::MockLoader* first_loader =
new NiceMock<test_util::MockLoader>();
1297 EXPECT_CALL(*first_loader, LoadWithMetadata(Loader::Metadata{
id})).Times(0);
1298 first_aspired_versions.push_back({id, std::unique_ptr<Loader>(first_loader)});
1299 manager_->GetAspiredVersionsCallback()(kServableName,
1300 std::move(first_aspired_versions));
1301 HandlePendingAspiredVersionsRequests();
// Un-aspire before any policy step has run.
1309 std::vector<ServableData<std::unique_ptr<Loader>>> empty_aspired_versions;
1310 manager_->GetAspiredVersionsCallback()(kServableName,
1311 std::move(empty_aspired_versions));
1312 HandlePendingAspiredVersionsRequests();
// Re-aspire the same id with a second loader that records its Load call.
1315 std::vector<ServableData<std::unique_ptr<Loader>>> second_aspired_versions;
1316 test_util::MockLoader* second_loader =
new NiceMock<test_util::MockLoader>();
1317 second_aspired_versions.push_back(
1318 {id, std::unique_ptr<Loader>(second_loader)});
1319 Notification second_load_called;
1320 EXPECT_CALL(*second_loader, LoadWithMetadata(Loader::Metadata{
id}))
1321 .WillOnce(InvokeWithoutArgs([&]() {
1322 second_load_called.Notify();
1325 manager_->GetAspiredVersionsCallback()(kServableName,
1326 std::move(second_aspired_versions));
// Handle the pending request (twice; source lines 1330-1333 between the
// calls are missing from this copy), then run the policy until the second
// loader is loaded.
1329 HandlePendingAspiredVersionsRequests();
1334 HandlePendingAspiredVersionsRequests();
1337 for (
int i = 0; i < kNumVersionsPerServable + 1; ++i) {
1338 InvokePolicyAndExecuteAction();
1340 second_load_called.WaitForNotification();
// gMock double for AspiredVersionPolicy. It lets a test intercept
// GetNextAction() and inspect the snapshots the manager passes to it.
// NOTE(review): source lines 1344 and 1347-1348 are missing from this copy.
// They presumably hold the `public:` access specifier, the MOCK_METHOD
// qualifier list, and the closing `};`. Restore them from upstream rather
// than guessing the base-class signature.
1343 class MockAspiredVersionPolicy :
public AspiredVersionPolicy {
1345 MOCK_METHOD(absl::optional<ServableAction>, GetNextAction,
1346 (
const std::vector<AspiredServableStateSnapshot>&),
// Verifies that the manager hands the policy a snapshot for every aspired
// version. It aspires kNumVersionsPerServable versions of kServableName,
// captures the vector passed to GetNextAction(), and checks its size. This
// is a plain TEST with its own manager, not the parameterized fixture.
// NOTE(review): source line 1358 is missing from this copy; the Create call
// below ends in `));`, so presumably a TF_*_OK( wrapper opened on that line.
// Lines 1374-1375, which presumably attach the capturing lambda to the
// expectation, are also missing.
1350 TEST(AspiredVersionsManagerTest, CallPolicyWithAllVersions) {
1351 std::unique_ptr<AspiredVersionsManager> manager;
1352 AspiredVersionsManager::Options manager_options;
// The manager takes ownership of the policy; `policy` stays a raw observer
// for setting expectations.
1353 MockAspiredVersionPolicy* policy =
new MockAspiredVersionPolicy;
1355 manager_options.manage_state_interval_micros = -1;
1356 manager_options.aspired_version_policy =
1357 std::unique_ptr<AspiredVersionPolicy>(policy);
1359 AspiredVersionsManager::Create(std::move(manager_options), &manager));
// Aspire versions 0 .. kNumVersionsPerServable-1. `servables` records the
// ids but is not read in the visible lines.
1360 std::set<ServableId> servables;
1361 std::vector<ServableData<std::unique_ptr<Loader>>> aspired_versions;
1362 for (
int i = 0; i < kNumVersionsPerServable; ++i) {
1363 const ServableId
id = {kServableName, i};
1364 aspired_versions.push_back(CreateAspiredVersion(
id));
1365 servables.insert(
id);
1367 manager->GetAspiredVersionsCallback()(kServableName,
1368 std::move(aspired_versions));
1369 test_util::AspiredVersionsManagerTestAccess(manager.get())
1370 .HandlePendingAspiredVersionsRequests();
// Capture the snapshots the policy receives and return no action.
1372 std::vector<AspiredServableStateSnapshot> all_versions;
1373 EXPECT_CALL(*policy, GetNextAction(_))
1376 const std::vector<AspiredServableStateSnapshot>& snapshots) {
1377 all_versions = snapshots;
1378 return absl::nullopt;
1380 test_util::AspiredVersionsManagerTestAccess(manager.get())
1381 .InvokePolicyAndExecuteAction();
// One snapshot per aspired version.
1382 EXPECT_EQ(kNumVersionsPerServable, all_versions.size());
absl::optional< ServableState > GetState(const ServableId &servable_id) const TF_LOCKS_EXCLUDED(mu_)