TensorFlow Serving C++ API Documentation
caching_manager.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/caching_manager.h"
17 
18 #include <map>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/types/optional.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow_serving/core/loader.h"
28 #include "tensorflow_serving/core/servable_data.h"
29 #include "tensorflow_serving/core/servable_handle.h"
30 #include "tensorflow_serving/core/servable_id.h"
31 
32 namespace tensorflow {
33 namespace serving {
34 
35 Status CachingManager::Create(
36  Options options, std::unique_ptr<LoaderFactory> loader_factory,
37  std::unique_ptr<CachingManager>* caching_manager) {
38  // Set up basic manager options from the caching manager options.
39  BasicManager::Options basic_manager_options;
40  basic_manager_options.resource_tracker = std::move(options.resource_tracker);
41  basic_manager_options.num_load_threads = options.num_load_threads;
42  basic_manager_options.num_unload_threads = options.num_unload_threads;
43  basic_manager_options.max_num_load_retries = options.max_num_load_retries;
44  basic_manager_options.load_retry_interval_micros =
45  options.load_retry_interval_micros;
46  basic_manager_options.env = options.env;
47  basic_manager_options.servable_event_bus = options.servable_event_bus;
48 
49  // Create a basic manager and use it to construct the caching manager.
50  std::unique_ptr<BasicManager> basic_manager;
51  TF_RETURN_IF_ERROR(
52  BasicManager::Create(std::move(basic_manager_options), &basic_manager));
53 
54  caching_manager->reset(
55  new CachingManager(std::move(loader_factory), std::move(basic_manager)));
56  return OkStatus();
57 }
58 
59 CachingManager::CachingManager(std::unique_ptr<LoaderFactory> loader_factory,
60  std::unique_ptr<BasicManager> basic_manager)
61  : loader_factory_(std::move(loader_factory)),
62  basic_manager_(std::move(basic_manager)) {}
63 
64 CachingManager::~CachingManager() {}
65 
66 Status CachingManager::GetUntypedServableHandle(
67  const ServableRequest& request,
68  std::unique_ptr<UntypedServableHandle>* const handle) {
69  if (request.version) {
70  return GetUntypedServableHandleForId({request.name, *request.version},
71  handle);
72  }
73  // Since there is no explicit version in the request, get the latest from the
74  // loader-factory.
75  const int64_t policy_dictated_version = loader_factory_->GetServableVersion(
76  request.name, request.auto_version_policy);
77  return GetUntypedServableHandleForId({request.name, policy_dictated_version},
78  handle);
79 }
80 
81 Status CachingManager::GetUntypedServableHandleForId(
82  const ServableId& servable_id,
83  std::unique_ptr<UntypedServableHandle>* handle) {
84  // Check if the underlying basic manager can already serve this request.
85  const Status handle_status = basic_manager_->GetUntypedServableHandle(
86  ServableRequest::FromId(servable_id), handle);
87 
88  // If the servable is already managed and loaded by the basic manager, serve
89  // it.
90  if (handle_status.ok() || handle_status.code() != error::NOT_FOUND) {
91  return handle_status;
92  }
93 
94  // Build the servable data corresponding to the servable-id.
95  ServableData<std::unique_ptr<Loader>> loader_data =
96  loader_factory_->CreateLoader(servable_id);
97 
98  // Load the servable corresponding to the servable-id. For multiple concurrent
99  // requests enforces that exactly one thread performs the load operation with
100  // the wrapped basic-manager. All other requests block until the load
101  // completes and then trivially succeed.
102  TF_RETURN_IF_ERROR(LoadServable(std::move(loader_data)));
103 
104  // Return the handle using the loaded servable data now.
105  return basic_manager_->GetUntypedServableHandle(
106  ServableRequest::FromId(servable_id), handle);
107 }
108 
109 Status CachingManager::LoadServable(
110  ServableData<std::unique_ptr<Loader>> loader_data) {
111  const ServableId servable_id = loader_data.id();
112 
113  std::shared_ptr<mutex> servable_id_mu;
114  {
115  mutex_lock l(load_mutex_map_mu_);
116  auto iter = load_mutex_map_.find(servable_id);
117  if (iter == load_mutex_map_.end()) {
118  iter =
119  load_mutex_map_.emplace(servable_id, std::make_shared<mutex>()).first;
120  }
121  servable_id_mu = iter->second;
122  }
123 
124  {
125  // Ensure only one thread attempts to load the servable at a time.
126  mutex_lock l(*servable_id_mu);
127 
128  // Retrieve the state of the servable from the wrapped basic-manager. The
129  // servable should already be managed by the basic-manager.
130  const absl::optional<ServableStateSnapshot<>> snapshot =
131  basic_manager_->GetManagedServableStateSnapshot(servable_id);
132  if (snapshot) {
133  // The servable is already being managed by 'basic_manager_'. Hence it
134  // ought to be loaded, based on CachingManager's implementation invariant
135  // of doing manage+load atomically.
136  if (snapshot.value().state != LoaderHarness::State::kReady) {
137  const string error_msg = strings::StrCat(
138  "Servable requested for load is already being managed, but is not "
139  "loaded: ",
140  servable_id.DebugString());
141  DCHECK(false) << error_msg;
142  return errors::Internal(error_msg);
143  }
144  } else {
145  // Load the servable since it has not been loaded yet based on its state.
146  //
147  // First, transfer the servable to the basic manager. The loader_data may
148  // contain an error and the basic manager is equipped to handle that
149  // appropriately. By propagating such errors back to the basic manager,
150  // the functionality of the event-bus and the servable state monitor are
151  // automatically available in the caching-manager as well (via the basic
152  // manager).
153  const Status manage_status =
154  basic_manager_->ManageServable(std::move(loader_data));
155  if (!manage_status.ok()) {
156  const string error_msg = strings::StrCat(
157  "Internal error: unable to transfer servable to 'basic_manager_': ",
158  manage_status.message());
159  DCHECK(false) << error_msg;
160  return errors::Internal(error_msg);
161  }
162 
163  Notification load_done;
164  Status load_status;
165  basic_manager_->LoadServable(servable_id, [&](const Status& status) {
166  load_status = status;
167  load_done.Notify();
168  });
169  load_done.WaitForNotification();
170  TF_RETURN_IF_ERROR(load_status);
171  }
172  }
173  servable_id_mu.reset();
174  MaybeEraseLoadMutexMapEntry(servable_id);
175  return OkStatus();
176 }
177 
178 void CachingManager::MaybeEraseLoadMutexMapEntry(
179  const ServableId& servable_id) {
180  mutex_lock l(load_mutex_map_mu_);
181  auto iter = load_mutex_map_.find(servable_id);
182  // Erase the entry from the map if one exists and if the mutex shared_ptr
183  // is the last remaining one.
184  if (iter != load_mutex_map_.end() && iter->second.unique()) {
185  load_mutex_map_.erase(iter);
186  }
187 }
188 
189 int64_t CachingManager::GetLoadMutexMapSize() const {
190  mutex_lock l(load_mutex_map_mu_);
191  return load_mutex_map_.size();
192 }
193 
194 std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
195 CachingManager::GetAvailableUntypedServableHandles() const {
196  return basic_manager_->GetAvailableUntypedServableHandles();
197 }
198 
199 std::vector<ServableId> CachingManager::ListAvailableServableIds() const {
200  return basic_manager_->ListAvailableServableIds();
201 }
202 
203 PathPrefixLoaderFactory::PathPrefixLoaderFactory(
204  const string& path_prefix,
205  std::unique_ptr<StoragePathSourceAdapter> adapter)
206  : path_prefix_(path_prefix), adapter_(std::move(adapter)) {}
207 
209  const ServableId& id) {
210  if (id.version != 0) {
212  id,
213  errors::FailedPrecondition("PathPrefixLoaderFactory only supports "
214  "single-version servables at version 0"));
215  }
216  const StoragePath servable_path = io::JoinPath(path_prefix_, id.name);
217  return adapter_->AdaptOneVersion({id, servable_path});
218 }
219 
221  const string& servable_name,
222  ServableRequest::AutoVersionPolicy policy) const {
223  return 0;
224 }
225 
226 } // namespace serving
227 } // namespace tensorflow
int64_t GetServableVersion(const string &servable_name, ServableRequest::AutoVersionPolicy policy) const override
ServableData< std::unique_ptr< Loader > > CreateLoader(const ServableId &id) override