TensorFlow Serving C++ API Documentation
load_servables_fast.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/load_servables_fast.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/types/optional.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow_serving/core/servable_state.h"
30 #include "tensorflow_serving/core/source.h"
31 #include "tensorflow_serving/core/target.h"
32 
33 namespace tensorflow {
34 namespace serving {
35 
36 namespace internal {
37 
38 uint32 GetManagerNumLoadThreads(AspiredVersionsManager* manager) {
39  return manager->num_load_threads();
40 }
41 
42 std::function<void(const uint32)> SetManagerNumLoadThreadsNotifier(
43  AspiredVersionsManager* manager) {
44  return manager->set_num_load_threads_observer_->Notifier();
45 }
46 
47 Status ConnectSourcesWithFastInitialLoad(
48  AspiredVersionsManager* manager,
49  std::vector<Source<std::unique_ptr<Loader>>*> sources,
50  const std::function<Status()>& wait_until_loaded_fn,
51  const uint32 num_threads) {
52  const uint32 prev_num_load_threads = GetManagerNumLoadThreads(manager);
53  std::function<void(const uint32)> set_manager_num_load_threads =
54  SetManagerNumLoadThreadsNotifier(manager);
55  set_manager_num_load_threads(num_threads);
56  for (Source<std::unique_ptr<Loader>>* source : sources) {
57  ConnectSourceToTarget(source, manager);
58  }
59  const Status status = wait_until_loaded_fn();
60  set_manager_num_load_threads(prev_num_load_threads);
61  return status;
62 }
63 
64 } // namespace internal
65 
66 Status ConnectSourceWithFastInitialLoad(
67  AspiredVersionsManager* manager, Source<std::unique_ptr<Loader>>* source,
68  ServableStateMonitor* servable_state_monitor,
69  const std::vector<ServableRequest>& initial_servables,
70  const uint32 num_threads) {
71  return ConnectSourcesWithFastInitialLoad(manager, {source},
72  servable_state_monitor,
73  initial_servables, num_threads);
74 }
75 
76 Status ConnectSourcesWithFastInitialLoad(
77  AspiredVersionsManager* manager,
78  std::vector<Source<std::unique_ptr<Loader>>*> sources,
79  ServableStateMonitor* servable_state_monitor,
80  const std::vector<ServableRequest>& initial_servables,
81  const uint32 num_threads) {
82  return internal::ConnectSourcesWithFastInitialLoad(
83  manager, sources,
84  [&]() {
85  std::map<ServableId, ServableState::ManagerState> states_reached;
86  const bool all_servables_available =
87  servable_state_monitor->WaitUntilServablesReachState(
88  initial_servables, ServableState::ManagerState::kAvailable,
89  &states_reached);
90  if (!all_servables_available) {
91  const int num_unavailable_servables = std::count_if(
92  states_reached.begin(), states_reached.end(),
93  [](const std::pair<ServableId, ServableState::ManagerState>&
94  id_and_state) {
95  return id_and_state.second !=
96  ServableState::ManagerState::kAvailable;
97  });
98  string message =
99  strings::StrCat(num_unavailable_servables,
100  " servable(s) did not become available: {");
101  for (const auto& id_and_state : states_reached) {
102  if (id_and_state.second !=
103  ServableState::ManagerState::kAvailable) {
104  absl::optional<ServableState> maybe_state =
105  servable_state_monitor->GetState(id_and_state.first);
106  const string error_msg =
107  maybe_state && !maybe_state.value().health.ok()
108  ? " due to error: " +
109  maybe_state.value().health.ToString()
110  : "";
111  strings::StrAppend(&message, "{",
112  id_and_state.first.DebugString(), error_msg,
113  "}, ");
114  }
115  }
116  strings::StrAppend(&message, "}");
117  return errors::Unknown(message);
118  }
119  return OkStatus();
120  },
121  num_threads);
122 }
123 
124 } // namespace serving
125 } // namespace tensorflow