16 #include "tensorflow_serving/core/load_servables_fast.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow_serving/core/servable_state.h"
30 #include "tensorflow_serving/core/source.h"
31 #include "tensorflow_serving/core/target.h"
33 namespace tensorflow {
// Returns the number of load threads that `manager` currently reports via
// AspiredVersionsManager::num_load_threads().
//
// Parameters:
//   manager: the manager to query; dereferenced unconditionally, so it must
//     be non-null.
38 uint32 GetManagerNumLoadThreads(AspiredVersionsManager* manager) {
39 return manager->num_load_threads();
// Returns the notifier callback of `manager`'s set_num_load_threads_observer_.
// The callback accepts a thread count. Judging by its name, invoking it
// changes the manager's number of load threads — confirm against the
// observer's implementation.
//
// Parameters:
//   manager: the manager whose observer is used; dereferenced
//     unconditionally, so it must be non-null.
// NOTE(review): this reads a trailing-underscore (presumably private) member,
// which implies the function is a friend of AspiredVersionsManager — confirm.
42 std::function<void(
const uint32)> SetManagerNumLoadThreadsNotifier(
43 AspiredVersionsManager* manager) {
44 return manager->set_num_load_threads_observer_->Notifier();
// Temporarily switches `manager` to `num_threads` load threads and connects
// every source in `sources` to it. It then calls `wait_until_loaded_fn` and
// afterwards restores the load-thread count the manager had on entry.
//
// Parameters:
//   manager: the target each source is connected to. Its load-thread count is
//     changed for the duration of this call.
//   sources: the sources to connect; each one is passed to
//     ConnectSourceToTarget.
//   wait_until_loaded_fn: invoked exactly once, after all sources are
//     connected. Presumably it blocks until the initial load is done; its
//     Status is captured in `status`.
//   num_threads: the load-thread count to use while connecting and waiting.
// NOTE(review): the return statement is not visible in this excerpt.
// Presumably `status` is returned — confirm.
47 Status ConnectSourcesWithFastInitialLoad(
48 AspiredVersionsManager* manager,
49 std::vector<Source<std::unique_ptr<Loader>>*> sources,
50 const std::function<Status()>& wait_until_loaded_fn,
51 const uint32 num_threads) {
// Remember the current count so it can be restored after the wait.
52 const uint32 prev_num_load_threads = GetManagerNumLoadThreads(manager);
53 std::function<void(
const uint32)> set_manager_num_load_threads =
54 SetManagerNumLoadThreadsNotifier(manager);
// Apply the temporary count before any source is connected.
55 set_manager_num_load_threads(num_threads);
56 for (Source<std::unique_ptr<Loader>>* source : sources) {
57 ConnectSourceToTarget(source, manager);
59 const Status status = wait_until_loaded_fn();
// The previous count is restored unconditionally, whatever Status the
// wait function returned.
60 set_manager_num_load_threads(prev_num_load_threads);
// Single-source convenience overload. It wraps `source` in a one-element
// vector and forwards that vector, together with every other argument
// unchanged, to the multi-source ConnectSourcesWithFastInitialLoad overload
// that takes a ServableStateMonitor. It returns that overload's Status as-is.
//
// Parameters:
//   manager: the manager the source is connected to.
//   source: the single source to connect.
//   servable_state_monitor: forwarded unchanged.
//   initial_servables: forwarded unchanged.
//   num_threads: forwarded unchanged.
66 Status ConnectSourceWithFastInitialLoad(
67 AspiredVersionsManager* manager, Source<std::unique_ptr<Loader>>* source,
68 ServableStateMonitor* servable_state_monitor,
69 const std::vector<ServableRequest>& initial_servables,
70 const uint32 num_threads) {
71 return ConnectSourcesWithFastInitialLoad(manager, {source},
72 servable_state_monitor,
73 initial_servables, num_threads);
76 Status ConnectSourcesWithFastInitialLoad(
77 AspiredVersionsManager* manager,
78 std::vector<Source<std::unique_ptr<Loader>>*> sources,
79 ServableStateMonitor* servable_state_monitor,
80 const std::vector<ServableRequest>& initial_servables,
81 const uint32 num_threads) {
82 return internal::ConnectSourcesWithFastInitialLoad(
85 std::map<ServableId, ServableState::ManagerState> states_reached;
86 const bool all_servables_available =
87 servable_state_monitor->WaitUntilServablesReachState(
88 initial_servables, ServableState::ManagerState::kAvailable,
90 if (!all_servables_available) {
91 const int num_unavailable_servables = std::count_if(
92 states_reached.begin(), states_reached.end(),
93 [](
const std::pair<ServableId, ServableState::ManagerState>&
95 return id_and_state.second !=
96 ServableState::ManagerState::kAvailable;
99 strings::StrCat(num_unavailable_servables,
100 " servable(s) did not become available: {");
101 for (
const auto& id_and_state : states_reached) {
102 if (id_and_state.second !=
103 ServableState::ManagerState::kAvailable) {
104 absl::optional<ServableState> maybe_state =
105 servable_state_monitor->GetState(id_and_state.first);
106 const string error_msg =
107 maybe_state && !maybe_state.value().health.ok()
108 ?
" due to error: " +
109 maybe_state.value().health.ToString()
111 strings::StrAppend(&message,
"{",
112 id_and_state.first.DebugString(), error_msg,
116 strings::StrAppend(&message,
"}");
117 return errors::Unknown(message);