16 #include "tensorflow_serving/core/caching_manager.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow_serving/core/loader.h"
28 #include "tensorflow_serving/core/servable_data.h"
29 #include "tensorflow_serving/core/servable_handle.h"
30 #include "tensorflow_serving/core/servable_id.h"
32 namespace tensorflow {
35 Status CachingManager::Create(
36 Options options, std::unique_ptr<LoaderFactory> loader_factory,
37 std::unique_ptr<CachingManager>* caching_manager) {
39 BasicManager::Options basic_manager_options;
40 basic_manager_options.resource_tracker = std::move(options.resource_tracker);
41 basic_manager_options.num_load_threads = options.num_load_threads;
42 basic_manager_options.num_unload_threads = options.num_unload_threads;
43 basic_manager_options.max_num_load_retries = options.max_num_load_retries;
44 basic_manager_options.load_retry_interval_micros =
45 options.load_retry_interval_micros;
46 basic_manager_options.env = options.env;
47 basic_manager_options.servable_event_bus = options.servable_event_bus;
50 std::unique_ptr<BasicManager> basic_manager;
52 BasicManager::Create(std::move(basic_manager_options), &basic_manager));
54 caching_manager->reset(
55 new CachingManager(std::move(loader_factory), std::move(basic_manager)));
59 CachingManager::CachingManager(std::unique_ptr<LoaderFactory> loader_factory,
60 std::unique_ptr<BasicManager> basic_manager)
61 : loader_factory_(std::move(loader_factory)),
62 basic_manager_(std::move(basic_manager)) {}
64 CachingManager::~CachingManager() {}
66 Status CachingManager::GetUntypedServableHandle(
67 const ServableRequest& request,
68 std::unique_ptr<UntypedServableHandle>*
const handle) {
69 if (request.version) {
70 return GetUntypedServableHandleForId({request.name, *request.version},
75 const int64_t policy_dictated_version = loader_factory_->GetServableVersion(
76 request.name, request.auto_version_policy);
77 return GetUntypedServableHandleForId({request.name, policy_dictated_version},
81 Status CachingManager::GetUntypedServableHandleForId(
82 const ServableId& servable_id,
83 std::unique_ptr<UntypedServableHandle>* handle) {
85 const Status handle_status = basic_manager_->GetUntypedServableHandle(
86 ServableRequest::FromId(servable_id), handle);
90 if (handle_status.ok() || handle_status.code() != error::NOT_FOUND) {
95 ServableData<std::unique_ptr<Loader>> loader_data =
96 loader_factory_->CreateLoader(servable_id);
102 TF_RETURN_IF_ERROR(LoadServable(std::move(loader_data)));
105 return basic_manager_->GetUntypedServableHandle(
106 ServableRequest::FromId(servable_id), handle);
109 Status CachingManager::LoadServable(
110 ServableData<std::unique_ptr<Loader>> loader_data) {
111 const ServableId servable_id = loader_data.id();
113 std::shared_ptr<mutex> servable_id_mu;
115 mutex_lock l(load_mutex_map_mu_);
116 auto iter = load_mutex_map_.find(servable_id);
117 if (iter == load_mutex_map_.end()) {
119 load_mutex_map_.emplace(servable_id, std::make_shared<mutex>()).first;
121 servable_id_mu = iter->second;
126 mutex_lock l(*servable_id_mu);
130 const absl::optional<ServableStateSnapshot<>> snapshot =
131 basic_manager_->GetManagedServableStateSnapshot(servable_id);
136 if (snapshot.value().state != LoaderHarness::State::kReady) {
137 const string error_msg = strings::StrCat(
138 "Servable requested for load is already being managed, but is not "
140 servable_id.DebugString());
141 DCHECK(
false) << error_msg;
142 return errors::Internal(error_msg);
153 const Status manage_status =
154 basic_manager_->ManageServable(std::move(loader_data));
155 if (!manage_status.ok()) {
156 const string error_msg = strings::StrCat(
157 "Internal error: unable to transfer servable to 'basic_manager_': ",
158 manage_status.message());
159 DCHECK(
false) << error_msg;
160 return errors::Internal(error_msg);
163 Notification load_done;
165 basic_manager_->LoadServable(servable_id, [&](
const Status& status) {
166 load_status = status;
169 load_done.WaitForNotification();
170 TF_RETURN_IF_ERROR(load_status);
173 servable_id_mu.reset();
174 MaybeEraseLoadMutexMapEntry(servable_id);
178 void CachingManager::MaybeEraseLoadMutexMapEntry(
179 const ServableId& servable_id) {
180 mutex_lock l(load_mutex_map_mu_);
181 auto iter = load_mutex_map_.find(servable_id);
184 if (iter != load_mutex_map_.end() && iter->second.unique()) {
185 load_mutex_map_.erase(iter);
189 int64_t CachingManager::GetLoadMutexMapSize()
const {
190 mutex_lock l(load_mutex_map_mu_);
191 return load_mutex_map_.size();
194 std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
195 CachingManager::GetAvailableUntypedServableHandles()
const {
196 return basic_manager_->GetAvailableUntypedServableHandles();
199 std::vector<ServableId> CachingManager::ListAvailableServableIds()
const {
200 return basic_manager_->ListAvailableServableIds();
203 PathPrefixLoaderFactory::PathPrefixLoaderFactory(
204 const string& path_prefix,
205 std::unique_ptr<StoragePathSourceAdapter> adapter)
206 : path_prefix_(path_prefix), adapter_(std::move(adapter)) {}
210 if (
id.version != 0) {
213 errors::FailedPrecondition(
"PathPrefixLoaderFactory only supports "
214 "single-version servables at version 0"));
216 const StoragePath servable_path = io::JoinPath(path_prefix_,
id.name);
217 return adapter_->AdaptOneVersion({id, servable_path});
221 const string& servable_name,
222 ServableRequest::AutoVersionPolicy policy)
const {
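// Stray declaration text left over from the source listing, kept here for
// reference only (these read as the in-class PathPrefixLoaderFactory
// declarations of the two methods defined above):
//   int64_t GetServableVersion(const string& servable_name,
//                              ServableRequest::AutoVersionPolicy policy) const override;
//   ServableData<std::unique_ptr<Loader>> CreateLoader(const ServableId& id) override;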