16 #ifndef TENSORFLOW_SERVING_CORE_BASIC_MANAGER_H_
17 #define TENSORFLOW_SERVING_CORE_BASIC_MANAGER_H_
22 #include <unordered_map>
25 #include "absl/types/optional.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/core/stringpiece.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow_serving/core/loader.h"
34 #include "tensorflow_serving/core/loader_harness.h"
35 #include "tensorflow_serving/core/manager.h"
36 #include "tensorflow_serving/core/servable_data.h"
37 #include "tensorflow_serving/core/servable_handle.h"
38 #include "tensorflow_serving/core/servable_id.h"
39 #include "tensorflow_serving/core/servable_state.h"
40 #include "tensorflow_serving/resources/resource_tracker.h"
41 #include "tensorflow_serving/util/event_bus.h"
42 #include "tensorflow_serving/util/executor.h"
43 #include "tensorflow_serving/util/fast_read_dynamic_ptr.h"
45 namespace tensorflow {
49 class BasicManagerTestAccess;
// Callback type that receives the ServableId of a servable. Per its name it is
// meant to run before that servable is loaded — TODO(review): confirm at the
// call site.
109 using PreLoadHook = std::function<void(
const ServableId&)>;
// Configuration fields. The enclosing struct's header is not part of this
// excerpt; presumably these are the members of the Options struct that
// Create() takes.

// Tracker used for servable resource accounting. NOTE(review): what happens
// when this is left null is not visible here — confirm.
116 std::unique_ptr<ResourceTracker> resource_tracker;
// Number of threads used to load servables; defaults to 0.
// NOTE(review): the meaning of 0 (e.g. "no thread pool") is not shown here.
121 uint32 num_load_threads = 0;
// Number of threads used to unload servables; defaults to 0.
126 uint32 num_unload_threads = 0;
// Predicate over an absl::Status; by its name, decides whether a failed model
// load should be retried — confirm.
129 std::function<bool(absl::Status)> should_retry_model_load;
// Upper bound on load retries (by name); defaults to 5.
139 uint32 max_num_load_retries = 5;
// Interval between load retries. The `_micros` suffix indicates microseconds;
// the default expression evaluates to 60,000,000 us, i.e. 60 seconds.
144 int64_t load_retry_interval_micros = 1LL * 60 * 1000 * 1000;
// Whether to flush filesystem caches; defaults to false. NOTE(review): when the
// flush happens is not visible in this excerpt.
150 bool flush_filesystem_caches =
false;
// Environment to use; defaults to Env::Default().
153 Env* env = Env::Default();
// Hook of type PreLoadHook; empty by default (default-constructed
// std::function).
157 PreLoadHook pre_load_hook;
// Factory: builds a manager from `options` (taken by value) and stores it in
// `*manager`. The returned Status reports success or failure.
159 static Status Create(
Options options, std::unique_ptr<BasicManager>* manager);
// Overrides the base-class lookup of a single untyped servable handle, which
// is written to `*untyped_handle`. NOTE(review): a preceding parameter
// (presumably the request) is absent from this excerpt.
167 Status GetUntypedServableHandle(
169 std::unique_ptr<UntypedServableHandle>* untyped_handle)
override;
// Overrides the base-class query that returns untyped handles keyed by
// ServableId (presumably one per currently available servable).
171 std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
172 GetAvailableUntypedServableHandles()
const override;
// The template declarations below are missing their name lines in this
// excerpt. Judging by their parameters and the out-of-class definitions later
// in the file, they presumably correspond to ManageServableWithAdditionalState,
// GetManagedServableStateSnapshots, GetManagedServableStateSnapshot and
// GetAdditionalServableState, respectively — TODO confirm.

// Takes ownership of `additional_state` (std::unique_ptr passed by value).
190 template <
typename T>
193 std::unique_ptr<T> additional_state);
// T defaults to std::nullptr_t; takes a servable name and is const.
208 template <
typename T = std::
nullptr_t>
210 const string& servable_name)
const;
// T defaults to std::nullptr_t.
217 template <
typename T = std::
nullptr_t>
227 template <
typename T>
// Constructor that receives the configuration as individual arguments. Their
// names and types mirror the configuration fields (env, thread counts, retry
// settings, resource tracker, pre-load hook). NOTE(review): the access
// specifier is not visible; presumably non-public so that callers go through
// Create().
275 BasicManager(Env* env, uint32 num_load_threads, uint32 num_unload_threads,
276 uint32 max_num_load_retries,
277 std::function<
bool(absl::Status)> should_retry_model_load,
278 int64_t load_retry_interval_micros,
bool flush_filesystem_caches,
279 std::unique_ptr<ResourceTracker> resource_tracker,
281 PreLoadHook pre_load_hook);
// Shared implementation for starting to manage `servable`. The second
// parameter is a factory that yields the std::shared_ptr<LoaderHarness> for the
// servable; its full signature is truncated in this excerpt.
// ManageServableWithAdditionalState passes a lambda taking
// (const ServableId&, std::unique_ptr<Loader>).
292 Status ManageServableInternal(
ServableData<std::unique_ptr<Loader>> servable,
293 std::function<std::shared_ptr<LoaderHarness>(
// Looks up a harness for `servable_id`; the caller must hold mu_.
// NOTE(review): the out-parameter line is absent from this excerpt, and the
// meaning of "healthy" is not defined here.
299 Status GetHealthyHarness(
const ServableId& servable_id,
301 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the loaders that (by name) are currently using resources; the caller
// must hold mu_.
306 std::vector<const Loader*> GetLoadersCurrentlyUsingResources() const
307 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Describes a requested load or unload; `Kind` selects which. The remaining
// members of this struct are not part of this excerpt.
311 struct LoadOrUnloadRequest {
312 enum class Kind { kLoad, kUnload };
// Load/unload pipeline. The lock annotations state each step's contract on mu_:
// TF_LOCKS_EXCLUDED(mu_) means the caller must NOT hold mu_, and
// TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) means the caller MUST hold it.
318 void LoadOrUnloadServable(
const LoadOrUnloadRequest& request,
323 void HandleLoadOrUnloadRequest(
const LoadOrUnloadRequest& request,
325 TF_LOCKS_EXCLUDED(mu_);
// Decision step for a request; outputs a harness through `*harness`.
339 Status ApproveLoadOrUnload(
const LoadOrUnloadRequest& request,
340 LoaderHarness** harness) TF_LOCKS_EXCLUDED(mu_);
// Requires mu_. Also receives the caller's mutex_lock, presumably so that it
// can release and re-acquire mu_ while waiting — TODO confirm in the .cc.
350 Status ApproveLoad(LoaderHarness* harness, mutex_lock* mu_lock)
351 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
356 Status ApproveUnload(LoaderHarness* harness) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Requires mu_; like ApproveLoad, it also receives the caller's mutex_lock.
365 Status ReserveResources(LoaderHarness* harness, mutex_lock* mu_lock)
366 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Execution step. NOTE(review): unlike ExecuteLoad/ExecuteUnload, this
// declaration carries no lock annotation in this excerpt.
373 Status ExecuteLoadOrUnload(
const LoadOrUnloadRequest& request,
374 LoaderHarness* harness);
377 Status ExecuteLoad(LoaderHarness* harness) TF_LOCKS_EXCLUDED(mu_);
380 Status ExecuteUnload(LoaderHarness* harness) TF_LOCKS_EXCLUDED(mu_);
// The caller must not hold mu_.
383 Status UnloadAllServables() TF_LOCKS_EXCLUDED(mu_);
// Refreshes the serving map (by name); the caller must hold mu_.
387 void UpdateServingMap() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Changes the number of load threads. The caller must not hold
// load_executor_mu_, which is the mutex guarding load_executor_.
401 void SetNumLoadThreads(uint32 num_load_threads)
402 TF_LOCKS_EXCLUDED(load_executor_mu_);
// Returns the current number of load threads. NOTE(review): presumably reads
// the std::atomic num_load_threads_ member — confirm.
403 uint32 num_load_threads() const;
// Multimap from servable name to harness. It is a multimap, so several entries
// may share the same name key. NOTE(review): the alias-name line is not in this
// excerpt; given its uses, this is presumably `using ManagedMap = ...`.
409 std::unordered_multimap<
string, std::shared_ptr<LoaderHarness>>;
// Finds the managed_map_ entry for `id`; the caller must hold mu_. Callers
// compare the result with managed_map_.end() to detect "not managed".
413 ManagedMap::iterator FindHarnessInMap(const ServableId&
id)
414 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Publishes `state` (by name, to the servable event bus). NOTE(review): the
// behavior when no bus is configured is not visible here.
418 void PublishOnEventBus(const ServableState& state);
// Options used when creating LoaderHarness instances (presumably).
420 LoaderHarness::Options harness_options_;
// Event bus for ServableState updates. The raw pointer does not express
// ownership — NOTE(review): presumably owned elsewhere; confirm.
424 EventBus<ServableState>* servable_event_bus_;
// Retry predicate; presumably a copy of the corresponding configuration field.
427 std::function<
bool(absl::Status)> should_retry_model_load_;
// All managed servables, keyed by servable name; guarded by mu_.
435 ManagedMap managed_map_ TF_GUARDED_BY(mu_);
// Members of a (presumably nested) ServingMap class. The class header is not in
// this excerpt; the type name comes from the serving_map_ member declared
// after it.

// Looks up a handle for `request` and writes it to `*untyped_handle`.
453 Status GetUntypedServableHandle(
454 const ServableRequest& request,
455 std::unique_ptr<UntypedServableHandle>* untyped_handle);
// Returns untyped handles keyed by ServableId.
459 std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
460 GetAvailableUntypedServableHandles()
const;
// Rebuilds this map's contents from `managed_map` (by name).
464 void Update(
const ManagedMap& managed_map);
// Multimap from ServableRequest to harness, using the custom HashRequest and
// EqRequest functors, which are not visible here. The alias-name line is
// missing from this excerpt; presumably `using HandlesMap = ...`.
477 std::unordered_multimap<ServableRequest,
478 std::shared_ptr<const LoaderHarness>,
479 HashRequest, EqRequest>;
// Held behind a FastReadDynamicPtr, presumably so that readers can cheaply see
// the current map while Update() publishes a new one — TODO confirm.
480 FastReadDynamicPtr<HandlesMap> handles_map_;
// Serving-side view of the managed servables.
482 ServingMap serving_map_;
// Current load-thread count. It is std::atomic, so it can be read
// concurrently without taking load_executor_mu_.
501 std::atomic<uint32> num_load_threads_;
// Constant after construction; defaults to false (presumably set from the
// matching configuration field).
503 const bool flush_filesystem_caches_ =
false;
// Guards load_executor_. Declared mutable, presumably so that it can be locked
// from const methods.
505 mutable mutex load_executor_mu_;
506 std::unique_ptr<Executor> load_executor_ TF_GUARDED_BY(load_executor_mu_);
// Executor for unloads; it carries no guard annotation.
510 std::unique_ptr<Executor> unload_executor_;
// Mutex for the load/unload decision phase (by name); mutable.
513 mutable mutex load_unload_decision_phase_mu_;
// Resource tracker; guarded by mu_.
517 std::unique_ptr<ResourceTracker> resource_tracker_ TF_GUARDED_BY(mu_);
// Number of in-flight load/unload executions; guarded by mu_, starts at 0.
520 int num_ongoing_load_unload_executions_ TF_GUARDED_BY(mu_) = 0;
// Condition variable associated (by name) with the counter above.
524 condition_variable num_ongoing_load_unload_executions_cv_;
// Pre-load hook (presumably copied from the configuration).
526 PreLoadHook pre_load_hook_;
// Deletes the copy constructor and copy-assignment operator.
528 TF_DISALLOW_COPY_AND_ASSIGN(BasicManager);
// Out-of-class definition. Starts managing the servable via
// ManageServableInternal, passing a factory lambda that builds the
// LoaderHarness from the id, the loader and `additional_state`. The
// `additional_state` argument is moved into the harness.
// NOTE(review): the lambda captures `additional_state` by reference and moves
// from it. That is only correct if ManageServableInternal invokes the factory
// synchronously (before this function returns) and at most once — confirm.
// NOTE(review): the signature line and the tail of the make_shared call are not
// part of this excerpt.
535 template <
typename T>
538 std::unique_ptr<T> additional_state) {
539 return ManageServableInternal(
541 [
this, &additional_state](
const ServableId&
id,
542 std::unique_ptr<Loader> loader) {
543 return std::make_shared<LoaderHarness>(
id, std::move(loader),
544 std::move(additional_state),
// Out-of-class definition. Returns one ServableStateSnapshot<T> per
// managed_map_ entry whose key equals `servable_name`; the result is empty if
// there is none. The vector's capacity is reserved up front from the size of
// the matching range.
// NOTE(review): managed_map_ is TF_GUARDED_BY(mu_), but the lock acquisition is
// not visible in this excerpt — confirm it precedes the equal_range() read.
549 template <
typename T>
550 std::vector<ServableStateSnapshot<T>>
552 const string& servable_name)
const {
555 const auto range = managed_map_.equal_range(servable_name);
556 std::vector<ServableStateSnapshot<T>> state_snapshots;
557 state_snapshots.reserve(std::distance(range.first, range.second));
558 for (
auto it = range.first; it != range.second; ++it) {
559 state_snapshots.push_back(it->second->loader_state_snapshot<T>());
562 return state_snapshots;
// Out-of-class definition. Returns the snapshot of the servable identified by
// `id`, or absl::nullopt when FindHarnessInMap() reports that it is not
// managed.
// NOTE(review): FindHarnessInMap requires mu_; the signature and any locking
// sit on lines that are not part of this excerpt.
565 template <
typename T>
566 absl::optional<ServableStateSnapshot<T>>
570 auto iter = FindHarnessInMap(
id);
571 if (iter == managed_map_.end()) {
572 return absl::nullopt;
574 return iter->second->loader_state_snapshot<T>();
// Out-of-class definition. Returns the additional state of type T attached to
// the servable identified by `id`. If `id` is not managed, DCHECK(false) fires
// in debug builds.
// NOTE(review): the value returned in that case in non-debug builds is on lines
// not part of this excerpt (presumably nullptr) — confirm.
577 template <
typename T>
581 auto iter = FindHarnessInMap(
id);
582 if (iter == managed_map_.end()) {
583 DCHECK(
false) <<
"This servable is not being managed by the manager: "
587 return iter->second->additional_state<T>();
void LoadServable(const ServableId &id, DoneCallback done_callback)
std::vector< ServableStateSnapshot< T > > GetManagedServableStateSnapshots(const string &servable_name) const
Status StopManagingServable(const ServableId &id)
std::vector< ServableId > ListAvailableServableIds() const override
Status ManageServableWithAdditionalState(ServableData< std::unique_ptr< Loader >> servable, std::unique_ptr< T > additional_state)
void CancelLoadServableRetry(const ServableId &id)
std::function< void(const Status &status)> DoneCallback
void UnloadServable(const ServableId &id, DoneCallback done_callback)
T * GetAdditionalServableState(const ServableId &id)
Status ManageServable(ServableData< std::unique_ptr< Loader >> servable)
absl::optional< ServableStateSnapshot< T > > GetManagedServableStateSnapshot(const ServableId &id)
std::vector< string > GetManagedServableNames() const