16 #include "tensorflow_serving/core/caching_manager.h"
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow_serving/core/servable_data.h"
31 #include "tensorflow_serving/core/servable_handle.h"
32 #include "tensorflow_serving/core/servable_id.h"
33 #include "tensorflow_serving/core/servable_state.h"
34 #include "tensorflow_serving/core/servable_state_monitor.h"
35 #include "tensorflow_serving/core/simple_loader.h"
36 #include "tensorflow_serving/core/test_util/fake_loader_source_adapter.h"
37 #include "tensorflow_serving/core/test_util/manager_test_util.h"
38 #include "tensorflow_serving/util/event_bus.h"
39 #include "tensorflow_serving/util/threadpool_executor.h"
41 namespace tensorflow {
45 using ::testing::HasSubstr;
46 using ::testing::UnorderedElementsAreArray;
// Test LoaderFactory whose loaders produce a std::string servable of the form
// "<name>-<version>" (e.g. "kServableName-30") for the requested ServableId.
// It also counts the loaders it has handed out and answers Earliest/Latest
// version queries from settable fields, so tests can steer how
// ServableRequest::Earliest()/Latest() resolve.
50 class StringLoaderFactory :
public CachingManager::LoaderFactory {
// `starting_version` seeds latest_version_; earliest_version_ starts at 0.
52 explicit StringLoaderFactory(
const int64_t starting_version)
53 : latest_version_(starting_version) {}
55 ~StringLoaderFactory()
override =
default;
// Returns a SimpleLoader<string> for `id` and increments
// num_loaders_dispensed_. The servable itself is not built here: the creator
// lambda is only handed to SimpleLoader (presumably it runs when the loader's
// Load() is called).
57 ServableData<std::unique_ptr<Loader>> CreateLoader(
58 const ServableId&
id)
override {
62 num_loaders_dispensed_++;
// NOTE(review): `servable_creator` captures by reference ([&]) and reads
// `id`, but it is stored inside the loader this function returns. If the
// loader is loaded after the caller's ServableId has been destroyed, `id`
// dangles. Capturing by value ([id]) would remove that lifetime dependency.
65 auto servable_creator = [&](std::unique_ptr<string>* servable) {
66 servable->reset(
new string);
// The servable value is "<name>-<version>", which the tests compare against.
67 **servable = strings::StrCat(
id.name,
"-",
id.version);
70 std::unique_ptr<Loader> loader;
71 loader.reset(
new SimpleLoader<string>(
72 servable_creator, SimpleLoader<string>::EstimateNoResources()));
73 return ServableData<std::unique_ptr<Loader>>(id, std::move(loader));
// Resolves an auto-version request: kEarliest -> earliest_version_,
// kLatest -> latest_version_. `request_name` is not consulted in the visible
// lines, so all servable names appear to share the same versions.
77 int64_t GetServableVersion(
78 const string& request_name,
79 ServableRequest::AutoVersionPolicy policy)
const override {
82 case ServableRequest::AutoVersionPolicy::kEarliest:
83 return earliest_version_;
84 case ServableRequest::AutoVersionPolicy::kLatest:
85 return latest_version_;
// Sets the version returned for kEarliest requests.
90 void set_earliest_version(int64_t version) {
92 earliest_version_ = version;
// Sets the version returned for kLatest requests.
96 void set_latest_version(int64_t version) {
98 latest_version_ = version;
// Number of CreateLoader() calls so far. Tests use it to check that a
// servable the manager already holds is reused rather than reloaded.
102 int64_t num_loaders_dispensed()
const {
104 return num_loaders_dispensed_;
// The fields below are annotated TF_GUARDED_BY(mu_). The declaration of mu_
// and the lock acquisitions are not visible in this extract.
112 int64_t earliest_version_ TF_GUARDED_BY(mu_) = 0;
115 int64_t latest_version_ TF_GUARDED_BY(mu_) = 0;
118 int64_t num_loaders_dispensed_ TF_GUARDED_BY(mu_) = 0;
120 TF_DISALLOW_COPY_AND_ASSIGN(StringLoaderFactory);
// Test LoaderFactory whose loaders always fail. The servable creator returns
// errors::Unknown("error loader-factory") instead of building a servable.
// Used to exercise the manager's error paths.
125 class ErrorLoaderFactory :
public CachingManager::LoaderFactory {
127 ErrorLoaderFactory() =
default;
128 ~ErrorLoaderFactory()
override =
default;
// Returns a SimpleLoader<string> for `id` whose creator always reports an
// Unknown error.
130 ServableData<std::unique_ptr<Loader>> CreateLoader(
131 const ServableId&
id)
override {
132 auto servable_creator = [&](std::unique_ptr<string>* servable) {
133 return errors::Unknown(
"error loader-factory");
135 std::unique_ptr<Loader> loader;
136 loader.reset(
new SimpleLoader<string>(
137 servable_creator, SimpleLoader<string>::EstimateNoResources()));
138 return ServableData<std::unique_ptr<Loader>>(id, std::move(loader));
// NOTE(review): the body of this override is not visible in this extract.
141 int64_t GetServableVersion(
142 const string& request_name,
143 ServableRequest::AutoVersionPolicy policy)
const override {
149 TF_DISALLOW_COPY_AND_ASSIGN(ErrorLoaderFactory);
// Servable names used by the tests below.
153 constexpr
char kServableName[] =
"kServableName";
154 constexpr
char kServableName2[] =
"kServableName2";
// NOTE(review): the use site is not visible in this extract. This is presumably
// the worker count for the request ThreadPoolExecutors in the concurrency tests.
156 constexpr
int kNumThreads = 10;
// Test parameter for CachingManagerTest. The fixture copies these counts into
// CachingManager::Options::num_load_threads / num_unload_threads.
160 struct ThreadPoolSizes {
161 uint64_t num_load_threads;
162 uint64_t num_unload_threads;
// Parameterized fixture. It builds a CachingManager backed by a
// StringLoaderFactory (starting version 0) that publishes servable states to an
// event bus observed by servable_state_monitor_. Thread-pool sizes come from
// the ThreadPoolSizes test parameter.
164 class CachingManagerTest :
public ::testing::TestWithParam<ThreadPoolSizes> {
// servable_state_monitor_ is initialized from servable_event_bus_. This is safe
// only because servable_event_bus_ is declared before it (see the members).
167 : servable_event_bus_(EventBus<ServableState>::CreateEventBus()),
168 servable_state_monitor_(servable_event_bus_.get()) {
169 CachingManager::Options options;
170 options.env = Env::Default();
171 options.servable_event_bus = servable_event_bus_.get();
172 options.num_load_threads = GetParam().num_load_threads;
173 options.num_unload_threads = GetParam().num_unload_threads;
// Presumably: at most one load retry, with no delay between attempts.
174 options.max_num_load_retries = 1;
175 options.load_retry_interval_micros = 0;
177 std::unique_ptr<StringLoaderFactory> string_loader_factory;
178 string_loader_factory.reset(
new StringLoaderFactory(0));
// Keep a non-owning pointer so tests can adjust versions after the factory's
// ownership moves into CachingManager::Create below.
179 string_loader_factory_ = string_loader_factory.get();
181 TF_CHECK_OK(CachingManager::Create(
182 std::move(options), std::move(string_loader_factory), &manager_));
// Builds a separate CachingManager with the same options as manager_, but
// backed by ErrorLoaderFactory, so every load fails. It shares
// servable_event_bus_, so its state changes reach servable_state_monitor_.
// NOTE(review): this options setup duplicates the constructor's; a shared
// helper would keep the two in sync.
188 std::unique_ptr<CachingManager> CreateManagerWithErrorLoaderFactory() {
189 CachingManager::Options options;
190 options.env = Env::Default();
191 options.servable_event_bus = servable_event_bus_.get();
192 options.num_load_threads = GetParam().num_load_threads;
193 options.num_unload_threads = GetParam().num_unload_threads;
194 options.max_num_load_retries = 1;
195 options.load_retry_interval_micros = 0;
197 std::unique_ptr<ErrorLoaderFactory> error_loader_factory;
198 error_loader_factory.reset(
new ErrorLoaderFactory);
200 std::unique_ptr<CachingManager> error_manager;
201 TF_CHECK_OK(CachingManager::Create(
202 std::move(options), std::move(error_loader_factory), &error_manager));
203 return error_manager;
// Size of manager_'s internal load-mutex map, read through the
// test_util::CachingManagerTestAccess helper.
208 int64_t GetLoadMutexMapSize() {
209 return test_util::CachingManagerTestAccess(manager_.get())
210 .GetLoadMutexMapSize();
// Declaration order matters. Members are constructed top to bottom (the monitor
// reads the bus in its initializer) and destroyed bottom to top, so manager_ is
// destroyed before the event bus it was given.
213 std::shared_ptr<EventBus<ServableState>> servable_event_bus_;
214 ServableStateMonitor servable_state_monitor_;
215 std::unique_ptr<CachingManager> manager_;
// Non-owning; the factory's ownership was passed to CachingManager::Create.
216 StringLoaderFactory* string_loader_factory_;
// Runs every CachingManagerTest case twice: with ThreadPoolSizes{0, 0} and with
// 4 load / 4 unload threads. The parameter-list opener is not visible in this
// extract.
// NOTE(review): newer googletest releases deprecate INSTANTIATE_TEST_CASE_P in
// favor of INSTANTIATE_TEST_SUITE_P; confirm the gtest version in use.
219 INSTANTIATE_TEST_CASE_P(
220 WithOrWithoutThreadPools, CachingManagerTest,
222 ThreadPoolSizes{0, 0} ,
223 ThreadPoolSizes{4, 4} ));
// Requesting a specific id yields a handle whose value is "kServableName-30"
// and whose id is the requested one.
228 TEST_P(CachingManagerTest, ServableHandleSingleRequest) {
230 const ServableId
id = {kServableName, 30};
231 ServableHandle<string> handle;
233 manager_->GetServableHandle(ServableRequest::FromId(
id), &handle));
234 EXPECT_EQ(
"kServableName-30", *handle);
235 EXPECT_EQ(
id, handle.id());
// Two successive requests for different versions (30, then 31) of the same
// servable each get a handle to their own version. Each request lives in its
// own scope; the scope braces are not visible in this extract.
238 TEST_P(CachingManagerTest, ServableHandleMultipleRequests) {
243 const ServableId
id = {kServableName, 30};
244 ServableHandle<string> handle;
246 manager_->GetServableHandle(ServableRequest::FromId(
id), &handle));
247 EXPECT_EQ(
"kServableName-30", *handle);
248 EXPECT_EQ(
id, handle.id());
252 const ServableId
id = {kServableName, 31};
253 ServableHandle<string> handle;
255 manager_->GetServableHandle(ServableRequest::FromId(
id), &handle));
256 EXPECT_EQ(
"kServableName-31", *handle);
257 EXPECT_EQ(
id, handle.id());
// An Earliest request resolves through the factory's earliest version (set to
// 30 here) and yields a handle to kServableName version 30.
263 TEST_P(CachingManagerTest, ServableHandleSingleRequestEarliest) {
264 string_loader_factory_->set_earliest_version(30);
265 ServableHandle<string> handle;
266 TF_ASSERT_OK(manager_->GetServableHandle(
267 ServableRequest::Earliest({kServableName}), &handle));
268 EXPECT_EQ(
"kServableName-30", *handle);
269 const ServableId
id = {kServableName, 30};
270 EXPECT_EQ(
id, handle.id());
// A Latest request resolves through the factory's latest version (set to 30
// here) and yields a handle to kServableName version 30.
275 TEST_P(CachingManagerTest, ServableHandleSingleRequestLatest) {
276 string_loader_factory_->set_latest_version(30);
277 ServableHandle<string> handle;
278 TF_ASSERT_OK(manager_->GetServableHandle(
279 ServableRequest::Latest({kServableName}), &handle));
280 EXPECT_EQ(
"kServableName-30", *handle);
281 const ServableId
id = {kServableName, 30};
282 EXPECT_EQ(
id, handle.id());
// After version 42 is loaded by id, an Earliest request that resolves to 42
// returns the same servable. The factory does not dispense a second loader:
// num_loaders_dispensed() stays at 1.
287 TEST_P(CachingManagerTest, ServableHandleMultipleRequestsEarliest) {
288 const ServableId
id = {kServableName, 42};
291 ServableHandle<string> handle;
293 manager_->GetServableHandle(ServableRequest::FromId(
id), &handle));
294 EXPECT_EQ(
"kServableName-42", *handle);
295 EXPECT_EQ(
id, handle.id());
297 EXPECT_EQ(1, string_loader_factory_->num_loaders_dispensed());
// Make Earliest resolve to the already-loaded version.
299 string_loader_factory_->set_earliest_version(42);
304 ServableHandle<string> handle;
305 TF_ASSERT_OK(manager_->GetServableHandle(
306 ServableRequest::Earliest({kServableName}), &handle));
307 EXPECT_EQ(
"kServableName-42", *handle);
308 EXPECT_EQ(
id, handle.id());
// Still one loader: the second request reused the cached servable.
311 EXPECT_EQ(1, string_loader_factory_->num_loaders_dispensed());
// After version 42 is loaded by id, a Latest request that resolves to 42
// returns the same servable. The factory does not dispense a second loader:
// num_loaders_dispensed() stays at 1.
317 TEST_P(CachingManagerTest, ServableHandleMultipleRequestsLatest) {
318 const ServableId
id = {kServableName, 42};
321 ServableHandle<string> handle;
323 manager_->GetServableHandle(ServableRequest::FromId(
id), &handle));
324 EXPECT_EQ(
"kServableName-42", *handle);
325 EXPECT_EQ(
id, handle.id());
327 EXPECT_EQ(1, string_loader_factory_->num_loaders_dispensed());
// Make Latest resolve to the already-loaded version.
329 string_loader_factory_->set_latest_version(42);
334 ServableHandle<string> handle;
335 TF_ASSERT_OK(manager_->GetServableHandle(
336 ServableRequest::Latest({kServableName}), &handle));
337 EXPECT_EQ(
"kServableName-42", *handle);
338 EXPECT_EQ(
id, handle.id());
// Still one loader: the second request reused the cached servable.
341 EXPECT_EQ(1, string_loader_factory_->num_loaders_dispensed());
// Requesting the fixture's string servable through a ServableHandle<int> fails
// with INVALID_ARGUMENT.
345 TEST_P(CachingManagerTest, ServableHandleWrongType) {
348 ServableHandle<int> handle;
349 const Status status = manager_->GetServableHandle(
350 ServableRequest::FromId({kServableName, 30}), &handle);
351 ASSERT_FALSE(status.ok()) << status;
352 EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
// With ErrorLoaderFactory every load fails, so GetServableHandle returns a
// non-OK status.
355 TEST_P(CachingManagerTest, ServableHandleError) {
357 std::unique_ptr<CachingManager> error_manager =
358 CreateManagerWithErrorLoaderFactory();
359 ServableHandle<string> handle;
360 const Status status = error_manager->GetServableHandle(
361 ServableRequest::FromId({kServableName, 30}), &handle);
362 EXPECT_FALSE(status.ok()) << status;
// With no prior requests, the manager reports no available string servables.
368 TEST_P(CachingManagerTest, AvailableServableHandlesNoRequests) {
369 std::map<ServableId, ServableHandle<string>> handles =
370 manager_->GetAvailableServableHandles<
string>();
372 EXPECT_EQ(0, handles.size());
// After three successful requests (versions 30 and 31 of kServableName, and
// version 32 of kServableName2), GetAvailableServableHandles<string>() returns
// exactly those three ids, in any order. The per-request handles appear to live
// in inner scopes whose braces are not visible here. If so, the test also shows
// that servables stay available after their request handles are released.
375 TEST_P(CachingManagerTest, AvailableServableHandlesMultipleRequests) {
380 ServableHandle<string> handle;
381 TF_ASSERT_OK(manager_->GetServableHandle(
382 ServableRequest::FromId({kServableName, 30}), &handle));
386 ServableHandle<string> handle;
387 TF_ASSERT_OK(manager_->GetServableHandle(
388 ServableRequest::FromId({kServableName, 31}), &handle));
392 ServableHandle<string> handle;
393 TF_ASSERT_OK(manager_->GetServableHandle(
394 ServableRequest::FromId({kServableName2, 32}), &handle));
396 const std::map<ServableId, ServableHandle<string>> handles =
397 manager_->GetAvailableServableHandles<
string>();
// Collect just the ids so they can be compared order-independently.
398 std::vector<ServableId> actual_keys;
399 for (
const auto& it_handle : handles) {
400 actual_keys.push_back(it_handle.first);
403 const std::vector<ServableId> expected_keys = {
404 {kServableName, 30}, {kServableName, 31}, {kServableName2, 32}};
405 EXPECT_THAT(actual_keys, UnorderedElementsAreArray(expected_keys));
// GetAvailableServableHandles<int>() returns nothing even while a string
// servable is loaded and held, because the type does not match.
408 TEST_P(CachingManagerTest, AvailableServableHandlesWrongType) {
409 ServableHandle<string> handle;
410 TF_ASSERT_OK(manager_->GetServableHandle(
411 ServableRequest::FromId({kServableName, 30}), &handle));
412 std::map<ServableId, ServableHandle<int>> handles =
413 manager_->GetAvailableServableHandles<
int>();
414 EXPECT_EQ(0, handles.size());
// A failed load leaves nothing available. After the error manager's request
// fails, its GetAvailableServableHandles<string>() is empty.
417 TEST_P(CachingManagerTest, AvailableServableHandlesError) {
419 std::unique_ptr<CachingManager> error_manager =
420 CreateManagerWithErrorLoaderFactory();
421 ServableHandle<string> handle;
422 const Status status = error_manager->GetServableHandle(
423 ServableRequest::FromId({kServableName, 30}), &handle);
424 ASSERT_FALSE(status.ok()) << status;
425 std::map<ServableId, ServableHandle<string>> handles =
426 error_manager->GetAvailableServableHandles<
string>();
427 EXPECT_EQ(0, handles.size());
// After three successful requests, ListAvailableServableIds() reports exactly
// those three ids, in any order.
433 TEST_P(CachingManagerTest, ListAvailableServableIdsMultipleRequests) {
436 ServableHandle<string> handle;
437 TF_ASSERT_OK(manager_->GetServableHandle(
438 ServableRequest::FromId({kServableName, 30}), &handle));
442 ServableHandle<string> handle;
443 TF_ASSERT_OK(manager_->GetServableHandle(
444 ServableRequest::FromId({kServableName, 31}), &handle));
448 ServableHandle<string> handle;
449 TF_ASSERT_OK(manager_->GetServableHandle(
450 ServableRequest::FromId({kServableName2, 32}), &handle));
452 const std::vector<ServableId> expected = {
453 {kServableName, 30}, {kServableName, 31}, {kServableName2, 32}};
454 EXPECT_THAT(manager_->ListAvailableServableIds(),
455 UnorderedElementsAreArray(expected));
// gMock matcher that compares `arg == servable_state`. Its description is
// servable_state.DebugString(). arg.DebugString() is streamed to the result
// listener, presumably on the mismatch path; the return statements are not
// visible in this extract.
461 MATCHER_P(EqualsServableState, servable_state, servable_state.DebugString()) {
462 if (arg == servable_state) {
465 *result_listener << arg.DebugString();
// A successful load is observed by the state monitor: its state for the id is
// kAvailable with an OK status.
469 TEST_P(CachingManagerTest, EventBusSingleRequest) {
470 ServableHandle<string> handle;
471 const ServableId
id = {kServableName, 30};
473 manager_->GetServableHandle(ServableRequest::FromId(
id), &handle));
476 const ServableState expected_published_state = {
477 id, ServableState::ManagerState::kAvailable, OkStatus()};
478 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
479 EqualsServableState(expected_published_state));
// A failed load is observed too. The error manager shares the fixture's event
// bus, so the monitor's state for the id is kEnd carrying the loader's
// errors::Unknown("error loader-factory") status.
// NOTE(review): `status` is not checked in the visible lines.
482 TEST_P(CachingManagerTest, EventBusErrorHandle) {
484 std::unique_ptr<CachingManager> error_manager =
485 CreateManagerWithErrorLoaderFactory();
486 ServableHandle<string> handle;
487 const ServableId
id = {kServableName, 30};
488 const Status status =
489 error_manager->GetServableHandle(ServableRequest::FromId(
id), &handle);
492 const ServableState expected_published_state = {
493 id, ServableState::ManagerState::kEnd,
494 errors::Unknown(
"error loader-factory")};
495 EXPECT_THAT(*servable_state_monitor_.
GetState(
id),
496 EqualsServableState(expected_published_state));
// Four requests for distinct versions (30..33) run concurrently on a
// ThreadPoolExecutor. All must succeed, the requested versions must be
// available (the middle expected ids are on lines not visible in this extract),
// and the manager's load-mutex map must be empty afterwards.
// NOTE(review): the declaration of status_mu and the end of the executor's
// scope are not visible here. The checks presumably run only after the executor
// has been destroyed and all scheduled work has finished.
502 TEST_P(CachingManagerTest, ConcurrentDisjointRequests) {
505 std::vector<Status> statuses(4);
507 ThreadPoolExecutor request_executor(Env::Default(),
"GetHandles",
509 for (
int i = 0; i < 4; i++) {
// `i` is captured by value so each task writes its own slot.
510 request_executor.Schedule([
this, i, &statuses, &status_mu]() {
511 ServableHandle<string> handle;
// Unlike the other tests, this passes a braced ServableId rather than
// ServableRequest::FromId(...). Presumably an overload or implicit conversion
// not visible here accepts it.
512 const Status status =
513 manager_->GetServableHandle({kServableName, i + 30}, &handle);
514 mutex_lock l(status_mu);
515 statuses[i] = status;
520 for (
int i = 0; i < 4; i++) {
521 mutex_lock l(status_mu);
522 EXPECT_EQ(OkStatus(), statuses[i]);
526 const std::map<ServableId, ServableHandle<string>> handles =
527 manager_->GetAvailableServableHandles<
string>();
528 std::vector<ServableId> actual_keys;
529 for (
const auto& it_handle : handles) {
530 actual_keys.push_back(it_handle.first);
533 const std::vector<ServableId> expected_keys = {{kServableName, 30},
536 {kServableName, 33}};
537 EXPECT_THAT(actual_keys, UnorderedElementsAreArray(expected_keys));
// Once all loads have finished, no per-servable load mutexes should remain.
540 EXPECT_EQ(0, GetLoadMutexMapSize());
// Eight concurrent requests alternate between versions 30 and 31, so several
// tasks request the same servable at once. All must succeed, exactly those two
// versions must be available, and the load-mutex map must be empty afterwards.
// NOTE(review): as in ConcurrentDisjointRequests, status_mu's declaration and
// the executor's closing scope are not visible in this extract.
543 TEST_P(CachingManagerTest, ConcurrentIntersectingRequests) {
545 std::vector<Status> statuses(8);
547 ThreadPoolExecutor request_executor(Env::Default(),
"GetHandles",
549 for (
int i = 0; i < 8; i++) {
// Even i -> version 30, odd i -> version 31.
551 const int version = i % 2 + 30;
552 const ServableId
id = {kServableName, version};
// `i` and `id` are captured by value, so each task keeps its own copies after
// this loop iteration ends.
553 request_executor.Schedule([
this, i,
id, &statuses, &status_mu]() {
554 ServableHandle<string> handle;
555 const Status status =
556 manager_->GetServableHandle(ServableRequest::FromId(
id), &handle);
557 mutex_lock l(status_mu);
558 statuses[i] = status;
563 for (
int i = 0; i < 8; i++) {
564 mutex_lock l(status_mu);
565 EXPECT_EQ(OkStatus(), statuses[i]);
569 const std::map<ServableId, ServableHandle<string>> handles =
570 manager_->GetAvailableServableHandles<
string>();
571 std::vector<ServableId> actual_keys;
572 for (
const auto& it_handle : handles) {
573 actual_keys.push_back(it_handle.first);
575 const std::vector<ServableId> expected_keys = {{kServableName, 30},
576 {kServableName, 31}};
577 EXPECT_THAT(actual_keys, UnorderedElementsAreArray(expected_keys));
// Once all loads have finished, no per-servable load mutexes should remain.
580 EXPECT_EQ(0, GetLoadMutexMapSize());
// PathPrefixLoaderFactory with prefix "prefix" and a
// FakeLoaderSourceAdapter("suffix") yields a loader whose loaded servable is
// "prefix/servable_name/suffix". The path is presumably prefix + "/" + name,
// with the fake adapter appending "/suffix". Both auto-version policies
// resolve to version 0.
585 TEST(PathPrefixLoaderFactoryTest, Basic) {
586 auto adapter = std::unique_ptr<StoragePathSourceAdapter>(
587 new test_util::FakeLoaderSourceAdapter(
"suffix"));
588 PathPrefixLoaderFactory factory(
"prefix", std::move(adapter));
590 ServableData<std::unique_ptr<Loader>> loader_data =
591 factory.CreateLoader({
"servable_name", 0});
592 TF_ASSERT_OK(loader_data.status());
593 std::unique_ptr<Loader> loader = loader_data.ConsumeDataOrDie();
594 TF_ASSERT_OK(loader->Load());
595 EXPECT_EQ(
"prefix/servable_name/suffix", *loader->servable().get<
string>());
597 EXPECT_EQ(0, factory.GetServableVersion(
598 "blah", ServableRequest::AutoVersionPolicy::kEarliest));
599 EXPECT_EQ(0, factory.GetServableVersion(
600 "blah", ServableRequest::AutoVersionPolicy::kLatest));
// PathPrefixLoaderFactory supports only version 0. Asking it for version 42
// yields an error status whose message says so.
603 TEST(PathPrefixLoaderFactoryTest, VersionOtherThanZeroYieldsError) {
604 auto adapter = std::unique_ptr<StoragePathSourceAdapter>(
605 new test_util::FakeLoaderSourceAdapter(
"suffix"));
606 PathPrefixLoaderFactory factory(
"prefix", std::move(adapter));
608 ServableData<std::unique_ptr<Loader>> loader_data =
609 factory.CreateLoader({
"servable_name", 42});
610 ASSERT_FALSE(loader_data.status().ok());
611 EXPECT_THAT(loader_data.status().ToString(),
612 HasSubstr(
"PathPrefixLoaderFactory only supports single-version "
613 "servables at version 0"));
// NOTE(review): stray declaration with no body and no terminating ';'. It looks
// like the ServableStateMonitor::GetState that the EventBus tests call via
// servable_state_monitor_.GetState(id), appended from that class's header
// rather than belonging to this test file. Confirm and remove.
absl::optional< ServableState > GetState(const ServableId &servable_id) const TF_LOCKS_EXCLUDED(mu_)