16 #include "tensorflow_serving/core/aspired_versions_manager.h"
24 #include <unordered_set>
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/context.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow_serving/core/loader_harness.h"
34 #include "tensorflow_serving/core/servable_handle.h"
35 #include "tensorflow_serving/core/source.h"
37 namespace tensorflow {
// Comparator passed to std::sort in GetNextAction() to order the candidate
// actions (one optional action per managed servable name); after sorting,
// the first element is the action that gets executed.
// NOTE(review): several lines of this struct (the operator() declaration
// line, handling of empty optionals, and the body of OrderActions()) are not
// visible in this listing.
54 struct CompareActions {
// Stores the caller-supplied ordering, which may be empty. Note that it is
// copied, not moved, into the member.
56 explicit CompareActions(
57 AspiredVersionsManager::CustomSortActionsFn custom_sort_actions)
58 : custom_sort_actions_(custom_sort_actions) {}
// Parameters of the comparison operator (its declaration line is not visible
// here): both sides are optional actions as returned by the policy.
61 const absl::optional<AspiredVersionPolicy::ServableAction>& lhs,
62 const absl::optional<AspiredVersionPolicy::ServableAction>& rhs) {
// NOTE(review): .value() on an empty optional fails, so the (not visible)
// lines before this point presumably order empty optionals separately.
70 if (custom_sort_actions_) {
// A caller-supplied ordering, when set, replaces the default one entirely.
71 return custom_sort_actions_(lhs.value(), rhs.value());
// Default ordering: true iff the action returned by OrderActions() has a
// different action kind than rhs, so two actions of the same kind compare
// as equivalent.
73 return OrderActions(lhs.value(), rhs.value()).action != rhs.value().action;
// Presumably returns whichever of lhs/rhs should be taken first (body not
// visible in this listing).
77 AspiredVersionPolicy::ServableAction OrderActions(
78 const AspiredVersionPolicy::ServableAction& lhs,
79 const AspiredVersionPolicy::ServableAction& rhs) {
// Empty when no custom ordering was supplied, in which case OrderActions()
// decides.
90 AspiredVersionsManager::CustomSortActionsFn custom_sort_actions_;
// Checks that every entry in 'versions' belongs to 'servable_name'. On the
// first mismatch it returns InvalidArgument naming both the expected name and
// the offending entry's name.
// NOTE(review): the success return and closing braces are not visible in
// this listing.
95 Status ValidateAspiredVersions(
96 const StringPiece servable_name,
97 const std::vector<ServableData<std::unique_ptr<Loader>>>& versions) {
98 for (
const auto& version : versions) {
99 if (servable_name != version.id().name) {
100 return errors::InvalidArgument(strings::StrCat(
101 "Servable name: ", servable_name,
102 " doesn't match name in servable version: ", version.id().name));
109 std::set<int64_t> GetVersionNumbers(
110 const std::vector<ServableData<std::unique_ptr<Loader>>>& versions) {
111 std::set<int64_t> version_numbers;
112 for (
const auto& version : versions) {
113 version_numbers.insert(version.id().version);
115 return version_numbers;
119 string ServableVersionsDebugString(
120 const std::vector<ServableData<std::unique_ptr<Loader>>>& versions) {
121 std::vector<string> version_strings;
122 version_strings.reserve(versions.size());
123 for (
const ServableData<std::unique_ptr<Loader>>& version : versions) {
124 version_strings.push_back(version.id().DebugString());
126 return str_util::Join(version_strings,
", ");
// Presumably AspiredVersionsManagerTargetImpl (its class-header line is not
// visible in this listing): the TargetBase whose aspired-versions callback is
// handed out by AspiredVersionsManager. It forwards every update to the
// owning manager ('parent_').
135 :
public TargetBase<std::unique_ptr<Loader>> {
// Receives the aspired version list for one servable stream and hands it
// (moved) to the manager's pending-request queue; the manager processes it
// later on its periodic manage-state pass.
143 void SetAspiredVersions(
144 const StringPiece servable_name,
145 std::vector<
ServableData<std::unique_ptr<Loader>>> versions)
override {
146 parent_->EnqueueAspiredVersionsRequest(servable_name, std::move(versions));
// Factory for AspiredVersionsManager. It does three things:
//  - rejects options without an aspired_version_policy (InvalidArgument);
//  - builds a BasicManager from the subset of 'options' it consumes;
//  - stores a new AspiredVersionsManager wrapping that BasicManager in
//    '*manager'.
// Owning members of 'options' (resource tracker, pre-load hook, retry
// predicate, policy, custom sort function) are moved out of it.
// NOTE(review): the tail of the error message, the error-propagation macro
// around BasicManager::Create and the final return are on lines not visible
// in this listing.
158 Status AspiredVersionsManager::Create(
159 Options options, std::unique_ptr<AspiredVersionsManager>* manager) {
160 if (options.aspired_version_policy ==
nullptr) {
161 return errors::InvalidArgument(
162 "AspiredVersionsManager::Options aspired_version_policy must be "
165 BasicManager::Options basic_manager_options;
166 basic_manager_options.resource_tracker = std::move(options.resource_tracker);
167 basic_manager_options.num_load_threads = options.num_load_threads;
168 basic_manager_options.num_unload_threads = options.num_unload_threads;
169 basic_manager_options.max_num_load_retries = options.max_num_load_retries;
170 basic_manager_options.load_retry_interval_micros =
171 options.load_retry_interval_micros;
172 basic_manager_options.flush_filesystem_caches =
173 options.flush_filesystem_caches;
174 basic_manager_options.env = options.env;
175 basic_manager_options.servable_event_bus = options.servable_event_bus;
176 basic_manager_options.pre_load_hook = std::move(options.pre_load_hook);
// Only forward should_retry_model_load when it is set, so an unset value
// leaves whatever BasicManager::Options holds by default in place.
177 if (options.should_retry_model_load) {
178 basic_manager_options.should_retry_model_load =
179 std::move(options.should_retry_model_load);
181 std::unique_ptr<BasicManager> basic_manager;
// NOTE(review): the macro wrapping this call is on a line not visible here;
// presumably it returns early if BasicManager::Create fails.
183 BasicManager::Create(std::move(basic_manager_options), &basic_manager));
// The same Env pointer given to the BasicManager is also passed to the
// manager, which uses it for its periodic manage-state thread options.
185 manager->reset(
new AspiredVersionsManager(
186 options.manage_state_interval_micros, options.env,
187 std::move(options.aspired_version_policy),
188 std::move(options.custom_sort_actions), std::move(basic_manager),
189 options.with_current_context));
// Set after construction because it is not a constructor parameter.
190 (manager->get())->enable_reload_servables_with_error_ =
191 options.enable_reload_servables_with_error;
// Constructor. It does the following:
//  - takes ownership of the policy and the BasicManager;
//  - creates the Target adapter that Sources feed;
//  - registers an observer that forwards load-thread-count changes;
//  - when manage_state_interval_micros > 0, starts a PeriodicFunction that on
//    every tick flushes servables, applies pending aspired-versions requests,
//    and executes the next policy action.
// NOTE(review): the else-branch opener, the second lambda's header and the
// closing braces are on lines not visible in this listing.
195 AspiredVersionsManager::AspiredVersionsManager(
196 int64_t manage_state_interval_micros, Env* env,
197 std::unique_ptr<AspiredVersionPolicy> aspired_version_policy,
198 CustomSortActionsFn custom_sort_actions,
199 std::unique_ptr<BasicManager> basic_manager,
bool with_current_context)
200 : aspired_version_policy_(std::move(aspired_version_policy)),
201 custom_sort_actions_(std::move(custom_sort_actions)),
202 target_impl_(new internal::AspiredVersionsManagerTargetImpl(this)),
203 basic_manager_(std::move(basic_manager)) {
// Observer whose callback forwards a new load-thread count to
// SetNumLoadThreads().
204 set_num_load_threads_observer_.reset(
205 new Observer<const uint32>([
this](
const uint32 num_load_threads) {
206 this->SetNumLoadThreads(num_load_threads);
// Only a positive interval creates the periodic manage-state thread. Its
// lambdas capture 'this'; the destructor resets manage_state_thread_
// explicitly.
// NOTE(review): pf_options (env, thread name prefix) is configured here, but
// neither visible PeriodicFunction construction passes it. If that matches
// the full source, 'env' and the thread name prefix are ignored. Confirm,
// since this listing drops some tokens.
208 if (manage_state_interval_micros > 0) {
209 PeriodicFunction::Options pf_options;
210 pf_options.env = env;
211 pf_options.thread_name_prefix =
"AspiredVersionsManager_ManageState_Thread";
// Optionally capture the constructing thread's tensorflow::Context and
// re-install it (WithContext) for every periodic run.
212 if (with_current_context) {
213 tensorflow::Context context(tensorflow::ContextKind::kThread);
214 manage_state_thread_.reset(
new PeriodicFunction(
215 [
this, context = std::move(context)]() {
216 tensorflow::WithContext wc(context);
217 this->FlushServables();
218 this->HandlePendingAspiredVersionsRequests();
219 this->InvokePolicyAndExecuteAction();
221 manage_state_interval_micros));
// Same periodic work, without context propagation.
223 manage_state_thread_.reset(
new PeriodicFunction(
225 this->FlushServables();
226 this->HandlePendingAspiredVersionsRequests();
227 this->InvokePolicyAndExecuteAction();
229 manage_state_interval_micros));
// Destructor. Tears down explicitly in this order:
//  1. the Target adapter (presumably so Sources can no longer reach this
//     manager);
//  2. the manage-state PeriodicFunction, whose callbacks capture 'this'.
// NOTE(review): intervening lines (likely comments) are not visible in this
// listing.
234 AspiredVersionsManager::~AspiredVersionsManager() {
237 target_impl_.reset();
240 manage_state_thread_.reset();
// Body of ListAvailableServableIds(); its signature line is not visible in
// this listing. It delegates directly to the BasicManager.
245 return basic_manager_->ListAvailableServableIds();
// Delegates the handle lookup for 'request' to the BasicManager and returns
// its status.
// NOTE(review): the 'request' parameter line is not visible in this listing.
248 Status AspiredVersionsManager::GetUntypedServableHandle(
250 std::unique_ptr<UntypedServableHandle>*
const untyped_handle) {
251 return basic_manager_->GetUntypedServableHandle(request, untyped_handle);
254 std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
255 AspiredVersionsManager::GetAvailableUntypedServableHandles()
const {
256 return basic_manager_->GetAvailableUntypedServableHandles();
// GetAspiredVersionsCallback(); its name line is not visible in this listing.
// Returns the Target adapter's callback, through which a Source sets the
// aspired versions of a servable stream on this manager.
259 Source<std::unique_ptr<Loader>>::AspiredVersionsCallback
261 return target_impl_->GetAspiredVersionsCallback();
// Validates an aspired-versions request and queues it under 'servable_name'.
// Because the request is assigned into the map, a newer request replaces any
// still-pending one for the same servable. The queue is drained by
// HandlePendingAspiredVersionsRequests().
264 void AspiredVersionsManager::EnqueueAspiredVersionsRequest(
265 const StringPiece servable_name,
266 std::vector<
ServableData<std::unique_ptr<Loader>>> versions) {
267 const Status validation_status =
268 ValidateAspiredVersions(servable_name, versions);
// A name mismatch is treated as a programming error: DCHECK in debug builds,
// logged otherwise. Presumably the request is then dropped; the early return
// is on a line not visible here.
269 DCHECK(validation_status.ok()) << validation_status.message();
270 if (!validation_status.ok()) {
271 LOG(ERROR) << validation_status.message();
// Only pending_aspired_versions_requests_mu_ is taken to touch the queue.
276 mutex_lock l(pending_aspired_versions_requests_mu_);
277 VLOG(2) <<
"Enqueueing aspired versions request: " << servable_name <<
": "
278 << ServableVersionsDebugString(versions);
// Overwrites any earlier pending request for this servable; the right-hand
// side is on a line not visible here.
279 pending_aspired_versions_requests_[string(servable_name)] =
// Reconciles the versions basic_manager_ manages for 'servable_name' with the
// version list in 'versions':
//  - managed versions whose number is not in 'versions' get is_aspired=false
//    and have any pending load retry cancelled;
//  - versions in 'versions' that are not currently aspired (the 'additions'
//    set below) are handed to basic_manager_ with Aspired{true};
//  - when enable_reload_servables_with_error_ is set, a requested version
//    recorded in current_aspired_versions_with_error is first removed from
//    basic_manager_.
// NOTE(review): several lines are not visible in this listing, so the
// description above is inferred from the visible statements. The missing
// lines include the condition guarding the with_error insert, the head of
// the set-operation call, the `should_add` assignments and the
// `ServableId id;` declaration.
284 void AspiredVersionsManager::ProcessAspiredVersionsRequest(
285 const StringPiece servable_name,
286 std::vector<ServableData<std::unique_ptr<Loader>>> versions) {
287 VLOG(2) <<
"Processing aspired versions request: " << servable_name <<
": "
288 << ServableVersionsDebugString(versions);
// Version numbers requested by this call.
290 const std::set<int64_t> next_aspired_versions = GetVersionNumbers(versions);
// Partition the versions currently managed for this servable.
296 std::set<int64_t> current_aspired_versions;
297 std::set<int64_t> current_aspired_versions_with_error;
298 const std::vector<ServableStateSnapshot<Aspired>> state_snapshots =
299 basic_manager_->GetManagedServableStateSnapshots<Aspired>(
300 string(servable_name));
301 for (
const ServableStateSnapshot<Aspired>& state_snapshot : state_snapshots) {
302 if (state_snapshot.additional_state->is_aspired) {
303 current_aspired_versions.insert(state_snapshot.id.version);
305 current_aspired_versions_with_error.insert(state_snapshot.id.version);
// Versions no longer requested: un-aspire them and stop retrying their
// loads.
309 if (next_aspired_versions.find(state_snapshot.id.version) ==
310 next_aspired_versions.end()) {
311 VLOG(1) <<
"Setting is_aspired=false for " << state_snapshot.id;
312 basic_manager_->GetAdditionalServableState<Aspired>(state_snapshot.id)
313 ->is_aspired =
false;
314 basic_manager_->CancelLoadServableRetry(state_snapshot.id);
// additions = next_aspired_versions minus current_aspired_versions
// (presumably std::set_difference; the call head is on a line not visible
// here).
321 std::set<int64_t> additions;
323 next_aspired_versions.begin(), next_aspired_versions.end(),
324 current_aspired_versions.begin(), current_aspired_versions.end(),
325 std::inserter(additions, additions.begin()));
329 for (
auto& version : versions) {
330 bool should_add =
false;
331 const auto& version_id = version.id();
332 if (additions.find(version.id().version) != additions.end()) {
// With enable_reload_servables_with_error_, a requested version found in
// current_aspired_versions_with_error is removed from basic_manager_
// (presumably so it can be managed afresh below).
335 if (enable_reload_servables_with_error_ &&
336 current_aspired_versions_with_error.find(version.id().version) !=
337 current_aspired_versions_with_error.end()) {
339 id.name = std::string(servable_name);
340 id.version = version_id.version;
341 const Status manage_status = basic_manager_->StopManagingServable(
id);
342 DCHECK(manage_status.ok()) << manage_status.message();
343 if (!manage_status.ok()) {
344 LOG(ERROR) <<
"Internal error: Unable to clear errored servable "
345 << version_id.DebugString()
346 <<
" from 'basic_manager_': " << manage_status.message();
// NOTE(review): 'version' is moved into basic_manager_ here, while
// 'version_id' was bound by reference to version.id() and is still used in
// the error log below. If id() returns a reference into 'version', the log
// may print a moved-from id; copying the ServableId beforehand would avoid
// this. Confirm against ServableData.
353 const Status manage_status =
354 basic_manager_->ManageServableWithAdditionalState(
355 std::move(version), std::unique_ptr<Aspired>(
new Aspired{
true}));
356 DCHECK(manage_status.ok()) << manage_status.message();
357 if (!manage_status.ok()) {
358 LOG(ERROR) <<
"Internal error: Unable to transfer servable "
359 << version_id.DebugString()
360 <<
" to 'basic_manager_': " << manage_status.message();
// Reports whether 'versions' names any version of 'servable_name' that
// basic_manager_ still manages but has marked is_aspired=false. In other
// words, whether the request would re-aspire a version that was previously
// un-aspired.
// NOTE(review): the return statements are on lines not visible in this
// listing; presumably true on the first match and false otherwise.
366 bool AspiredVersionsManager::ContainsAnyReaspiredVersions(
367 const StringPiece servable_name,
368 const std::vector<ServableData<std::unique_ptr<Loader>>>& versions)
const {
369 const std::vector<ServableStateSnapshot<Aspired>> state_snapshots =
370 basic_manager_->GetManagedServableStateSnapshots<Aspired>(
371 string(servable_name));
372 const std::set<int64_t> version_numbers = GetVersionNumbers(versions);
373 for (
const ServableStateSnapshot<Aspired>& state_snapshot : state_snapshots) {
374 if (!state_snapshot.additional_state->is_aspired &&
375 version_numbers.find(state_snapshot.id.version) !=
376 version_numbers.end()) {
// Asks aspired_version_policy_ for one candidate action per managed servable
// name. The candidates are sorted with CompareActions, which honours
// custom_sort_actions_ when set. Returns the first candidate, or nullopt if
// there are no candidates.
385 absl::optional<AspiredVersionPolicy::ServableAction>
386 AspiredVersionsManager::GetNextAction() {
387 std::vector<absl::optional<AspiredVersionPolicy::ServableAction>> actions;
388 for (
const string& servable_name :
389 basic_manager_->GetManagedServableNames()) {
// Build the policy's view (id, state, is_aspired) of every managed version
// of this servable.
390 std::vector<AspiredServableStateSnapshot> aspired_state_snapshots;
391 for (
const ServableStateSnapshot<Aspired>& state_snapshot :
392 basic_manager_->GetManagedServableStateSnapshots<Aspired>(
394 aspired_state_snapshots.push_back(
395 {state_snapshot.id, state_snapshot.state,
396 state_snapshot.additional_state->is_aspired});
// One (possibly empty) candidate per servable name.
398 actions.emplace_back(
399 aspired_version_policy_->GetNextAction(aspired_state_snapshots));
// Sort so that actions[0] is the action to execute next.
402 std::sort(actions.begin(), actions.end(),
403 CompareActions(custom_sort_actions_));
404 const absl::optional<AspiredVersionPolicy::ServableAction> next_action =
405 !actions.empty() ? actions[0] : absl::nullopt;
// NOTE(review): '->' on an empty optional is undefined. The guard for this
// log line is presumably on a line not visible here.
407 VLOG(1) <<
"Taking action: " << next_action->DebugString();
// Executes 'action' against basic_manager_ by loading or unloading
// action.id. Each call takes a completion callback that captures 'action' by
// value; the visible callback bodies only log an error.
// NOTE(review): the case labels and the status checks are on lines not
// visible here; presumably kLoad/kUnload and `if (!status.ok())`.
412 void AspiredVersionsManager::PerformAction(
413 const AspiredVersionPolicy::ServableAction action) {
414 switch (action.action) {
// Load branch.
416 basic_manager_->LoadServable(action.id, [action](
const Status& status) {
418 LOG(ERROR) <<
"Servable " << action.id.DebugString()
419 <<
" cannot be loaded: " << status;
// Unload branch.
424 basic_manager_->UnloadServable(action.id, [action](
const Status& status) {
426 LOG(ERROR) <<
"Servable " << action.id.DebugString()
427 <<
" cannot be unloaded: " << status;
// Stops managing every servable version that is not aspired and is in state
// kNew, kDisabled or kError. The whole scan runs while holding
// basic_manager_read_modify_write_mu_. A failed removal is logged and left
// for a later pass ("will retry later").
434 void AspiredVersionsManager::FlushServables() {
435 mutex_lock l(basic_manager_read_modify_write_mu_);
436 for (
const string& servable_name :
437 basic_manager_->GetManagedServableNames()) {
438 for (
const ServableStateSnapshot<Aspired>& state_snapshot :
439 basic_manager_->GetManagedServableStateSnapshots<Aspired>(
441 if ((state_snapshot.state == LoaderHarness::State::kNew ||
442 state_snapshot.state == LoaderHarness::State::kDisabled ||
443 state_snapshot.state == LoaderHarness::State::kError) &&
444 !state_snapshot.additional_state->is_aspired) {
445 const Status status =
446 basic_manager_->StopManagingServable(state_snapshot.id);
// NOTE(review): this message and the error message below lack a space before
// "from", so the id runs into the word "from" in the log output.
448 VLOG(1) <<
"Removed " << state_snapshot.id <<
"from BasicManager";
// Failure path; the branch structure is on lines not visible here.
454 LOG(ERROR) <<
"Error removing " << state_snapshot.id
455 <<
"from BasicManager: " << status <<
" will retry later";
// Drains the queue of pending aspired-versions requests. It holds both
// basic_manager_read_modify_write_mu_ and
// pending_aspired_versions_requests_mu_, acquired in that order.
// A request that would re-aspire a version still managed as not-aspired is
// postponed. Presumably it stays in the map for a later tick; the iterator
// advance is not visible here. Every other request is processed and erased.
462 void AspiredVersionsManager::HandlePendingAspiredVersionsRequests() {
463 mutex_lock l(basic_manager_read_modify_write_mu_);
464 mutex_lock l2(pending_aspired_versions_requests_mu_);
471 for (
auto it = pending_aspired_versions_requests_.begin();
472 it != pending_aspired_versions_requests_.end();) {
473 const string& servable_name = it->first;
474 std::vector<ServableData<std::unique_ptr<Loader>>>& versions = it->second;
476 if (ContainsAnyReaspiredVersions(servable_name, versions)) {
479 VLOG(1) <<
"Postponing processing of aspired versions request due to "
480 "re-aspired version(s) among: "
481 << ServableVersionsDebugString(versions);
// 'versions' is moved out before the entry is erased; erase() yields the
// iterator to the next entry.
483 ProcessAspiredVersionsRequest(servable_name, std::move(versions));
484 it = pending_aspired_versions_requests_.erase(it);
// While holding basic_manager_read_modify_write_mu_, obtains the next action
// and executes it with PerformAction(). Presumably the action comes from
// GetNextAction(); that call and its empty check are on lines not visible
// here.
489 void AspiredVersionsManager::InvokePolicyAndExecuteAction() {
490 mutex_lock l(basic_manager_read_modify_write_mu_);
492 const absl::optional<AspiredVersionPolicy::ServableAction> next_action =
499 PerformAction(*next_action);
502 void AspiredVersionsManager::SetNumLoadThreads(
const uint32 num_load_threads) {
503 basic_manager_->SetNumLoadThreads(num_load_threads);
506 uint32 AspiredVersionsManager::num_load_threads()
const {
507 return basic_manager_->num_load_threads();
@ kUnload
Call unload on the servable.
@ kLoad
Call load on the servable.
std::vector< ServableId > ListAvailableServableIds() const override
Source< std::unique_ptr< Loader > >::AspiredVersionsCallback GetAspiredVersionsCallback() override
Returns a callback to set the list of aspired versions for a particular servable stream,...