16 #include "tensorflow_serving/core/basic_manager.h"
23 #include <unordered_set>
27 #include "absl/status/status.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow_serving/core/servable_handle.h"
33 #include "tensorflow_serving/core/servable_state.h"
34 #include "tensorflow_serving/core/source.h"
35 #include "tensorflow_serving/resources/resource_tracker.h"
36 #include "tensorflow_serving/util/event_bus.h"
37 #include "tensorflow_serving/util/hash.h"
38 #include "tensorflow_serving/util/inline_executor.h"
39 #include "tensorflow_serving/util/retrier.h"
40 #include "tensorflow_serving/util/threadpool_executor.h"
42 namespace tensorflow {
47 std::unique_ptr<Executor> CreateExecutor(Env*
const env,
48 const uint32 num_threads,
49 const string& threadpool_name) {
50 std::unique_ptr<Executor> executor;
51 if (num_threads == 0) {
52 executor.reset(
new InlineExecutor());
54 executor.reset(
new ThreadPoolExecutor(env, threadpool_name, num_threads));
64 if (lhs.version != rhs.version) {
67 if (lhs.auto_version_policy != rhs.auto_version_policy) {
73 if (lhs.name != rhs.name) {
92 const uint64_t version_hash = [&]() -> uint64_t {
93 if (request.version) {
94 return std::hash<int64_t>()(request.version.value()) *
97 switch (request.auto_version_policy) {
98 case ServableRequest::AutoVersionPolicy::kEarliest:
100 case ServableRequest::AutoVersionPolicy::kLatest:
106 return HashCombine(version_hash, std::hash<string>()(request.name));
110 BasicManager::ServingMap::ServingMap()
111 : handles_map_(std::unique_ptr<HandlesMap>(new HandlesMap())) {}
113 std::vector<ServableId> BasicManager::ServingMap::ListAvailableServableIds()
115 std::vector<ServableId> ids;
116 std::shared_ptr<const HandlesMap> handles_map = handles_map_.get();
117 for (
auto iter = handles_map->begin(); iter != handles_map->end();) {
119 const auto key_end = handles_map->equal_range(iter->first).second;
121 for (; iter != key_end; ++iter) {
122 if (iter->first.version) {
123 ids.push_back(iter->second->id());
130 Status BasicManager::ServingMap::GetUntypedServableHandle(
131 const ServableRequest& request,
132 std::unique_ptr<UntypedServableHandle>*
const untyped_handle) {
133 std::shared_ptr<const HandlesMap> handles_map = handles_map_.get();
134 const auto found_it = handles_map->find(request);
135 if (found_it == handles_map->end()) {
136 return errors::NotFound(
"Servable not found for request: ",
137 request.DebugString());
140 const LoaderHarness& harness = *found_it->second;
145 untyped_handle->reset(
new SharedPtrHandle(
146 harness.id(), std::shared_ptr<Loader>(handles_map, harness.loader())));
150 std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
151 BasicManager::ServingMap::GetAvailableUntypedServableHandles()
const {
152 std::map<ServableId, std::unique_ptr<UntypedServableHandle>> result;
153 std::shared_ptr<const HandlesMap> handles_map = handles_map_.get();
154 for (
const auto& handle : *handles_map) {
155 const ServableRequest& request = handle.first;
158 if (!request.version) {
161 const LoaderHarness& harness = *handle.second;
162 result.emplace(harness.id(),
163 std::unique_ptr<UntypedServableHandle>(
new SharedPtrHandle(
164 harness.id(), std::shared_ptr<Loader>(
165 handles_map, harness.loader()))));
170 void BasicManager::ServingMap::Update(
const ManagedMap& managed_map) {
171 struct CompareRequests {
172 bool operator()(
const ServableRequest& lhs,
173 const ServableRequest& rhs)
const {
174 const int strcmp_result = lhs.name.compare(rhs.name);
175 if (strcmp_result != 0) {
176 return strcmp_result < 0;
180 return lhs.version.value() < rhs.version.value();
183 std::multimap<ServableRequest, std::shared_ptr<const LoaderHarness>,
185 sorted_available_map;
186 for (
const auto& elem : managed_map) {
187 std::shared_ptr<const LoaderHarness> harness = elem.second;
189 sorted_available_map.emplace(ServableRequest::FromId(harness->id()),
194 std::unique_ptr<HandlesMap> new_handles_map(
new HandlesMap());
195 auto prev_iter = sorted_available_map.end();
196 for (
auto iter = sorted_available_map.begin();
197 iter != sorted_available_map.end(); ++iter) {
198 std::shared_ptr<const LoaderHarness> harness = iter->second;
199 new_handles_map->emplace(ServableRequest::FromId(harness->id()), harness);
203 if (prev_iter == sorted_available_map.end() ||
204 prev_iter->second->id().name != harness->id().name) {
205 const ServableRequest earliest_request =
206 ServableRequest::Earliest(harness->id().name);
207 new_handles_map->emplace(earliest_request, harness);
212 const auto next_iter = std::next(iter);
213 if (next_iter == sorted_available_map.end() ||
214 next_iter->second->id().name != harness->id().name) {
215 const ServableRequest latest_request =
216 ServableRequest::Latest(harness->id().name);
217 new_handles_map->emplace(latest_request, harness);
225 handles_map_.Update(std::move(new_handles_map));
228 Status BasicManager::Create(Options options,
229 std::unique_ptr<BasicManager>* manager) {
230 manager->reset(
new BasicManager(
231 options.env, options.num_load_threads, options.num_unload_threads,
232 options.max_num_load_retries, std::move(options.should_retry_model_load),
233 options.load_retry_interval_micros, options.flush_filesystem_caches,
234 std::move(options.resource_tracker), options.servable_event_bus,
235 std::move(options.pre_load_hook)));
239 BasicManager::BasicManager(
240 Env*
const env,
const uint32 num_load_threads,
241 const uint32 num_unload_threads, uint32 max_num_load_retries,
242 std::function<
bool(absl::Status)> should_retry_model_load,
243 int64_t load_retry_interval_micros,
bool flush_filesystem_caches,
244 std::unique_ptr<ResourceTracker> resource_tracker,
245 EventBus<ServableState>* servable_event_bus,
246 std::function<
void(
const ServableId&)> pre_load_hook)
247 : servable_event_bus_(servable_event_bus),
248 should_retry_model_load_(std::move(should_retry_model_load)),
250 num_load_threads_(num_load_threads),
251 flush_filesystem_caches_(flush_filesystem_caches),
252 pre_load_hook_(std::move(pre_load_hook)) {
253 harness_options_.max_num_load_retries = max_num_load_retries;
254 harness_options_.load_retry_interval_micros = load_retry_interval_micros;
255 harness_options_.error_callback = [
this](
const ServableId& id,
256 const Status& error) {
257 PublishOnEventBus({id, ServableState::ManagerState::kEnd, error});
261 mutex_lock l(load_executor_mu_);
263 CreateExecutor(env_, num_load_threads,
"BasicManager_Load_ThreadPool");
265 unload_executor_ = CreateExecutor(env_, num_unload_threads,
266 "BasicManager_Unload_ThreadPool");
267 resource_tracker_ = std::move(resource_tracker);
273 mutex_lock l(load_executor_mu_);
274 load_executor_.reset();
276 unload_executor_.reset();
278 const Status unload_status = UnloadAllServables();
279 if (!unload_status.ok()) {
280 LOG(ERROR) <<
"Error unloading all servables in BasicManager destructor: "
285 Status BasicManager::UnloadAllServables() {
286 LOG(INFO) <<
"Unload all remaining servables in the manager.";
287 Status status = OkStatus();
290 for (
auto it = managed_map_.begin(); it != managed_map_.end(); ++it) {
296 status.Update(harness->
Unload());
300 status.Update(harness->
Unload());
303 status.Update(harness->
Unload());
312 return serving_map_.ListAvailableServableIds();
315 Status BasicManager::GetUntypedServableHandle(
317 std::unique_ptr<UntypedServableHandle>*
const untyped_handle) {
318 return serving_map_.GetUntypedServableHandle(request, untyped_handle);
321 std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
322 BasicManager::GetAvailableUntypedServableHandles()
const {
323 return serving_map_.GetAvailableUntypedServableHandles();
326 void BasicManager::UpdateServingMap() {
329 serving_map_.Update(managed_map_);
332 BasicManager::ManagedMap::iterator BasicManager::FindHarnessInMap(
333 const ServableId&
id) {
334 const auto range = managed_map_.equal_range(
id.name);
335 for (
auto iter = range.first; iter != range.second; ++iter) {
336 if (iter->second->id().version ==
id.version) {
340 return managed_map_.end();
343 Status BasicManager::ManageServableInternal(
344 ServableData<std::unique_ptr<Loader>> servable,
345 std::function<std::shared_ptr<LoaderHarness>(
const ServableId&,
346 std::unique_ptr<Loader>)>
348 VLOG(1) <<
"Request to start managing servable " << servable.id();
352 const auto iter = BasicManager::FindHarnessInMap(servable.id());
353 if (iter != managed_map_.end()) {
354 return errors::FailedPrecondition(
355 "This servable is already being managed: ",
356 servable.id().DebugString());
359 std::unique_ptr<Loader> loader;
360 if (servable.status().ok()) {
361 loader = servable.ConsumeDataOrDie();
364 std::shared_ptr<LoaderHarness> harness =
365 harness_creator(servable.id(), std::move(loader));
366 if (should_retry_model_load_) {
367 harness->set_should_retry(should_retry_model_load_);
369 if (!servable.status().ok()) {
370 harness->Error(servable.status());
372 PublishOnEventBus({harness->id(), ServableState::ManagerState::kStart,
375 managed_map_.emplace(servable.id().name, harness);
382 return ManageServableInternal(
384 [
this](
const ServableId&
id, std::unique_ptr<Loader> loader) {
385 return std::make_shared<LoaderHarness>(
id, std::move(loader),
391 VLOG(1) <<
"Request to stop managing servable " << id;
393 const auto it = FindHarnessInMap(
id);
394 if (it == managed_map_.end()) {
395 LOG(ERROR) <<
"Request to delete harness for " <<
id
396 <<
", but no such harness found in managed_map_";
397 return errors::FailedPrecondition(
"This servable is not being managed: ",
400 const auto state = it->second->state();
404 LOG(ERROR) <<
"Request to delete harness for " <<
id
405 <<
", but it is not in a new or end state. State: " << state;
406 return errors::FailedPrecondition(
407 "This servable is not in a new or end state and we cannot stop "
409 id.DebugString(),
" ", LoaderHarness::StateDebugString(state));
411 managed_map_.erase(it);
415 Status BasicManager::GetHealthyHarness(
const ServableId&
id,
418 auto iter = FindHarnessInMap(
id);
419 if (iter == managed_map_.end()) {
420 return errors::NotFound(
421 "This servable is not being managed by the manager: ",
424 TF_RETURN_IF_ERROR(iter->second->status());
425 *harness = iter->second.get();
429 std::vector<const Loader*> BasicManager::GetLoadersCurrentlyUsingResources()
431 std::vector<const Loader*> loaders;
432 for (
const auto& entry : managed_map_) {
433 const LoaderHarness& harness = *entry.second;
435 switch (harness.state()) {
437 uses_resources =
false;
440 uses_resources =
false;
443 uses_resources =
true;
446 uses_resources =
true;
449 uses_resources =
true;
452 uses_resources =
true;
455 uses_resources =
true;
458 uses_resources =
true;
461 uses_resources =
true;
464 uses_resources =
false;
467 uses_resources =
false;
470 if (uses_resources) {
471 loaders.push_back(harness.loader());
480 std::vector<string> servable_names;
481 for (
auto iter = managed_map_.begin(); iter != managed_map_.end();
482 iter = managed_map_.equal_range(iter->first).second) {
483 servable_names.push_back(iter->first);
485 return servable_names;
489 PublishOnEventBus({harness->
id(), ServableState::ManagerState::kLoading,
494 const ServableId
id = harness->
id();
496 if (pre_load_hook_) {
501 const Status status = harness->
Load();
505 if (flush_filesystem_caches_ && num_load_threads() <= 1) {
506 const Status flush_status = Env::Default()->FlushFileSystemCaches();
507 if (!flush_status.ok()) {
508 LOG(WARNING) <<
"flushing filesystem caches failed: " << flush_status;
512 TF_RETURN_IF_ERROR(status);
519 PublishOnEventBus({id, ServableState::ManagerState::kAvailable, OkStatus()});
525 VLOG(1) <<
"Request to load servable " << id;
526 LoadOrUnloadRequest request;
527 request.kind = LoadOrUnloadRequest::Kind::kLoad;
528 request.servable_id = id;
529 LoadOrUnloadServable(request, done_callback);
535 const Status status = GetHealthyHarness(
id, &harness);
552 {id, ServableState::ManagerState::kUnloading, harness->
status()});
558 TF_RETURN_IF_ERROR(harness->
Unload());
559 PublishOnEventBus({id, ServableState::ManagerState::kEnd, OkStatus()});
565 VLOG(1) <<
"Request to unload servable " << id;
566 LoadOrUnloadRequest request;
567 request.kind = LoadOrUnloadRequest::Kind::kUnload;
568 request.servable_id = id;
569 LoadOrUnloadServable(request, done_callback);
572 Status BasicManager::ExecuteLoadOrUnload(
const LoadOrUnloadRequest& request,
574 Status execution_status;
575 switch (request.kind) {
576 case LoadOrUnloadRequest::Kind::kLoad:
577 execution_status = ExecuteLoad(harness);
579 case LoadOrUnloadRequest::Kind::kUnload:
580 execution_status = ExecuteUnload(harness);
586 --num_ongoing_load_unload_executions_;
587 DCHECK_GE(num_ongoing_load_unload_executions_, 0);
588 num_ongoing_load_unload_executions_cv_.notify_all();
591 return execution_status;
594 void BasicManager::SetNumLoadThreads(
const uint32 num_load_threads) {
606 const uint32 old_num_threads = num_load_threads_.load();
607 if (old_num_threads < 2 || num_load_threads < 2) {
608 mutex_lock l(load_executor_mu_);
609 load_executor_.reset();
610 num_load_threads_.store(num_load_threads);
612 CreateExecutor(env_, num_load_threads,
"BasicManager_Load_ThreadPool");
614 std::unique_ptr<Executor> old_executor;
616 mutex_lock l(load_executor_mu_);
617 old_executor = std::move(load_executor_);
618 num_load_threads_.store(num_load_threads);
619 load_executor_ = CreateExecutor(env_, num_load_threads,
620 "BasicManager_Load_ThreadPool");
625 uint32 BasicManager::num_load_threads()
const {
626 return num_load_threads_.load();
629 void BasicManager::LoadOrUnloadServable(
const LoadOrUnloadRequest& request,
631 const Status status = [&]() {
633 LoaderHarness* harness;
634 TF_RETURN_IF_ERROR(GetHealthyHarness(request.servable_id, &harness));
637 switch (request.kind) {
638 case LoadOrUnloadRequest::Kind::kLoad:
639 TF_RETURN_IF_ERROR(harness->LoadRequested());
641 case LoadOrUnloadRequest::Kind::kUnload:
642 TF_RETURN_IF_ERROR(harness->UnloadRequested());
648 done_callback(status);
652 switch (request.kind) {
653 case LoadOrUnloadRequest::Kind::kLoad: {
654 mutex_lock l(load_executor_mu_);
655 load_executor_->Schedule([
this, request, done_callback]() {
656 HandleLoadOrUnloadRequest(request, done_callback);
660 case LoadOrUnloadRequest::Kind::kUnload: {
661 unload_executor_->Schedule([
this, request, done_callback]() {
662 HandleLoadOrUnloadRequest(request, done_callback);
669 void BasicManager::HandleLoadOrUnloadRequest(
const LoadOrUnloadRequest& request,
672 Status decision_status;
673 LoaderHarness* harness;
678 mutex_lock l(load_unload_decision_phase_mu_);
679 decision_status = ApproveLoadOrUnload(request, &harness);
681 if (!decision_status.ok()) {
682 done_callback(decision_status);
687 const Status execution_status = ExecuteLoadOrUnload(request, harness);
688 done_callback(execution_status);
691 Status BasicManager::ApproveLoadOrUnload(
const LoadOrUnloadRequest& request,
692 LoaderHarness** harness) {
695 TF_RETURN_IF_ERROR(GetHealthyHarness(request.servable_id, harness));
697 switch (request.kind) {
698 case LoadOrUnloadRequest::Kind::kLoad: {
699 TF_RETURN_IF_ERROR(ApproveLoad(*harness, &l));
702 case LoadOrUnloadRequest::Kind::kUnload: {
703 TF_RETURN_IF_ERROR(ApproveUnload(*harness));
708 ++num_ongoing_load_unload_executions_;
713 Status BasicManager::ApproveLoad(LoaderHarness* harness, mutex_lock* mu_lock) {
714 if (resource_tracker_ !=
nullptr) {
716 const Status resource_reservation_status =
717 ReserveResources(harness, mu_lock);
718 if (!resource_reservation_status.ok()) {
719 LOG(WARNING) << resource_reservation_status;
720 harness->Error(resource_reservation_status);
721 PublishOnEventBus({harness->id(), ServableState::ManagerState::kEnd,
722 resource_reservation_status});
723 return resource_reservation_status;
730 TF_RETURN_IF_ERROR(harness->LoadApproved());
735 Status BasicManager::ApproveUnload(LoaderHarness* harness) {
738 TF_RETURN_IF_ERROR(harness->StartQuiescing());
743 Status BasicManager::ReserveResources(LoaderHarness* harness,
744 mutex_lock* mu_lock) {
746 TF_RETURN_IF_ERROR(resource_tracker_->RecomputeUsedResources(
747 GetLoadersCurrentlyUsingResources()));
748 bool resources_reserved;
751 const Status reserve_resources_status = Retry(
752 strings::StrCat(
"Reserving resources for servable: ",
753 harness->id().DebugString()),
754 harness_options_.max_num_load_retries,
755 harness_options_.load_retry_interval_micros,
756 [&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
757 return resource_tracker_->ReserveResources(*harness->loader(),
758 &resources_reserved);
760 [&](absl::Status status) {
return harness->should_retry(status); });
761 if (!reserve_resources_status.ok()) {
763 reserve_resources_status.code(),
765 "Error while attempting to reserve resources to load servable ",
766 harness->id().DebugString(),
": ",
767 reserve_resources_status.message()));
769 if (resources_reserved) {
771 LOG(INFO) <<
"Successfully reserved resources to load servable "
772 << harness->id().DebugString();
779 if (num_ongoing_load_unload_executions_ == 0) {
782 return errors::ResourceExhausted(
783 "Insufficient resources to load servable ",
784 harness->id().DebugString());
787 VLOG(1) <<
"Waiting for another load/unload request to finish";
788 num_ongoing_load_unload_executions_cv_.wait(*mu_lock);
793 void BasicManager::PublishOnEventBus(
const ServableState& state) {
794 if (servable_event_bus_ !=
nullptr) {
795 servable_event_bus_->Publish(state);
void LoadServable(const ServableId &id, DoneCallback done_callback)
Status StopManagingServable(const ServableId &id)
std::vector< ServableId > ListAvailableServableIds() const override
void CancelLoadServableRetry(const ServableId &id)
std::function< void(const Status &status)> DoneCallback
void UnloadServable(const ServableId &id, DoneCallback done_callback)
Status ManageServable(ServableData< std::unique_ptr< Loader >> servable)
std::vector< string > GetManagedServableNames() const
void set_should_retry(std::function< bool(absl::Status)> should_retry) TF_LOCKS_EXCLUDED(mu_)
Status DoneQuiescing() TF_LOCKS_EXCLUDED(mu_)
State state() const TF_LOCKS_EXCLUDED(mu_)
Returns the current state of the underlying Servable.
Status Unload() TF_LOCKS_EXCLUDED(mu_)
Status UnloadRequested() TF_LOCKS_EXCLUDED(mu_)
Status StartQuiescing() TF_LOCKS_EXCLUDED(mu_)
@ kLoadRequested
The manager has been requested to load this servable.
@ kReady
'loader_->Load()' has succeeded.
@ kUnloadRequested
The manager has been requested to unload this servable.
@ kDisabled
'loader_->Unload()' has finished.
@ kLoading
'loader_->Load()' has been called.
@ kQuiesced
The servable has been made unavailable for serving.
@ kQuiescing
The servable is going to be made unavailable for serving.
@ kUnloading
'loader_->Unload()' has been called.
ServableId id() const
Returns the identifier of the underlying Servable.
Status Load() TF_LOCKS_EXCLUDED(mu_)
Status status() const TF_LOCKS_EXCLUDED(mu_)