TensorFlow Serving C++ API Documentation
aspired_versions_manager.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/aspired_versions_manager.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <unordered_set>
25 #include <utility>
26 #include <vector>
27 
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/context.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow_serving/core/loader_harness.h"
34 #include "tensorflow_serving/core/servable_handle.h"
35 #include "tensorflow_serving/core/source.h"
36 
37 namespace tensorflow {
38 namespace serving {
39 
40 namespace {
41 
// Per-servable "aspired" flag, attached to each managed servable as its
// additional state.
//
// The flag is wrapped in a struct rather than stored as a bare bool so that it
// has a distinct type and cannot be implicitly converted to or from a pointer
// by accident.
struct Aspired {
  // True while the servable version is part of the aspired set.
  bool is_aspired;
};
49 
50 // Decides which action amongst the 2 to take. We prefer an unload action over a
51 // load action.
52 //
53 // Note that this returns a strict weak ordering.
54 struct CompareActions {
55  public:
56  explicit CompareActions(
57  AspiredVersionsManager::CustomSortActionsFn custom_sort_actions)
58  : custom_sort_actions_(custom_sort_actions) {}
59 
60  bool operator()(
61  const absl::optional<AspiredVersionPolicy::ServableAction>& lhs,
62  const absl::optional<AspiredVersionPolicy::ServableAction>& rhs) {
63  if (!lhs) {
64  return false;
65  }
66  if (!rhs) {
67  return true;
68  }
69  // By this point, we are sure the optionals have values.
70  if (custom_sort_actions_) {
71  return custom_sort_actions_(lhs.value(), rhs.value());
72  }
73  return OrderActions(lhs.value(), rhs.value()).action != rhs.value().action;
74  }
75 
76  private:
77  AspiredVersionPolicy::ServableAction OrderActions(
78  const AspiredVersionPolicy::ServableAction& lhs,
79  const AspiredVersionPolicy::ServableAction& rhs) {
80  switch (lhs.action) {
82  return lhs;
84  if (rhs.action == AspiredVersionPolicy::Action::kUnload) {
85  return rhs;
86  }
87  return lhs;
88  }
89  }
90  AspiredVersionsManager::CustomSortActionsFn custom_sort_actions_;
91 };
92 
93 // Validates whether all entries in 'versions' pertain to the servable named
94 // 'servable_name'.
95 Status ValidateAspiredVersions(
96  const StringPiece servable_name,
97  const std::vector<ServableData<std::unique_ptr<Loader>>>& versions) {
98  for (const auto& version : versions) {
99  if (servable_name != version.id().name) {
100  return errors::InvalidArgument(strings::StrCat(
101  "Servable name: ", servable_name,
102  " doesn't match name in servable version: ", version.id().name));
103  }
104  }
105  return OkStatus();
106 }
107 
108 // Returns the set of version numbers in 'versions'.
109 std::set<int64_t> GetVersionNumbers(
110  const std::vector<ServableData<std::unique_ptr<Loader>>>& versions) {
111  std::set<int64_t> version_numbers;
112  for (const auto& version : versions) {
113  version_numbers.insert(version.id().version);
114  }
115  return version_numbers;
116 }
117 
118 // Creates a debug string for a given vector of servable versions.
119 string ServableVersionsDebugString(
120  const std::vector<ServableData<std::unique_ptr<Loader>>>& versions) {
121  std::vector<string> version_strings;
122  version_strings.reserve(versions.size());
123  for (const ServableData<std::unique_ptr<Loader>>& version : versions) {
124  version_strings.push_back(version.id().DebugString());
125  }
126  return str_util::Join(version_strings, ", ");
127 }
128 
129 } // namespace
130 
131 namespace internal {
132 
133 // AspiredVersionsManager's implementation of the Target API.
135  : public TargetBase<std::unique_ptr<Loader>> {
136  public:
138  AspiredVersionsManager* const parent)
139  : parent_(parent) {}
140  ~AspiredVersionsManagerTargetImpl() override { Detach(); }
141 
142  protected:
143  void SetAspiredVersions(
144  const StringPiece servable_name,
145  std::vector<ServableData<std::unique_ptr<Loader>>> versions) override {
146  parent_->EnqueueAspiredVersionsRequest(servable_name, std::move(versions));
147  }
148 
149  private:
150  // A pointer to the manager whose Target implementation this is.
151  AspiredVersionsManager* const parent_;
152 
153  TF_DISALLOW_COPY_AND_ASSIGN(AspiredVersionsManagerTargetImpl);
154 };
155 
156 } // namespace internal
157 
158 Status AspiredVersionsManager::Create(
159  Options options, std::unique_ptr<AspiredVersionsManager>* manager) {
160  if (options.aspired_version_policy == nullptr) {
161  return errors::InvalidArgument(
162  "AspiredVersionsManager::Options aspired_version_policy must be "
163  "non-null");
164  }
165  BasicManager::Options basic_manager_options;
166  basic_manager_options.resource_tracker = std::move(options.resource_tracker);
167  basic_manager_options.num_load_threads = options.num_load_threads;
168  basic_manager_options.num_unload_threads = options.num_unload_threads;
169  basic_manager_options.max_num_load_retries = options.max_num_load_retries;
170  basic_manager_options.load_retry_interval_micros =
171  options.load_retry_interval_micros;
172  basic_manager_options.flush_filesystem_caches =
173  options.flush_filesystem_caches;
174  basic_manager_options.env = options.env;
175  basic_manager_options.servable_event_bus = options.servable_event_bus;
176  basic_manager_options.pre_load_hook = std::move(options.pre_load_hook);
177  if (options.should_retry_model_load) {
178  basic_manager_options.should_retry_model_load =
179  std::move(options.should_retry_model_load);
180  }
181  std::unique_ptr<BasicManager> basic_manager;
182  TF_RETURN_IF_ERROR(
183  BasicManager::Create(std::move(basic_manager_options), &basic_manager));
184 
185  manager->reset(new AspiredVersionsManager(
186  options.manage_state_interval_micros, options.env,
187  std::move(options.aspired_version_policy),
188  std::move(options.custom_sort_actions), std::move(basic_manager),
189  options.with_current_context));
190  (manager->get())->enable_reload_servables_with_error_ =
191  options.enable_reload_servables_with_error;
192  return OkStatus();
193 }
194 
195 AspiredVersionsManager::AspiredVersionsManager(
196  int64_t manage_state_interval_micros, Env* env,
197  std::unique_ptr<AspiredVersionPolicy> aspired_version_policy,
198  CustomSortActionsFn custom_sort_actions,
199  std::unique_ptr<BasicManager> basic_manager, bool with_current_context)
200  : aspired_version_policy_(std::move(aspired_version_policy)),
201  custom_sort_actions_(std::move(custom_sort_actions)),
202  target_impl_(new internal::AspiredVersionsManagerTargetImpl(this)),
203  basic_manager_(std::move(basic_manager)) {
204  set_num_load_threads_observer_.reset(
205  new Observer<const uint32>([this](const uint32 num_load_threads) {
206  this->SetNumLoadThreads(num_load_threads);
207  }));
208  if (manage_state_interval_micros > 0) {
209  PeriodicFunction::Options pf_options;
210  pf_options.env = env;
211  pf_options.thread_name_prefix = "AspiredVersionsManager_ManageState_Thread";
212  if (with_current_context) {
213  tensorflow::Context context(tensorflow::ContextKind::kThread);
214  manage_state_thread_.reset(new PeriodicFunction(
215  [this, context = std::move(context)]() {
216  tensorflow::WithContext wc(context);
217  this->FlushServables();
218  this->HandlePendingAspiredVersionsRequests();
219  this->InvokePolicyAndExecuteAction();
220  },
221  manage_state_interval_micros));
222  } else {
223  manage_state_thread_.reset(new PeriodicFunction(
224  [this]() {
225  this->FlushServables();
226  this->HandlePendingAspiredVersionsRequests();
227  this->InvokePolicyAndExecuteAction();
228  },
229  manage_state_interval_micros));
230  }
231  }
232 }
233 
234 AspiredVersionsManager::~AspiredVersionsManager() {
235  // Shut off incoming aspired-versions calls. It is important to do this before
236  // tearing down any other manager state.
237  target_impl_.reset();
238 
239  // This will wait till the thread is joined.
240  manage_state_thread_.reset();
241 }
242 
244  const {
245  return basic_manager_->ListAvailableServableIds();
246 }
247 
248 Status AspiredVersionsManager::GetUntypedServableHandle(
249  const ServableRequest& request,
250  std::unique_ptr<UntypedServableHandle>* const untyped_handle) {
251  return basic_manager_->GetUntypedServableHandle(request, untyped_handle);
252 }
253 
254 std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
255 AspiredVersionsManager::GetAvailableUntypedServableHandles() const {
256  return basic_manager_->GetAvailableUntypedServableHandles();
257 }
258 
259 Source<std::unique_ptr<Loader>>::AspiredVersionsCallback
261  return target_impl_->GetAspiredVersionsCallback();
262 }
263 
264 void AspiredVersionsManager::EnqueueAspiredVersionsRequest(
265  const StringPiece servable_name,
266  std::vector<ServableData<std::unique_ptr<Loader>>> versions) {
267  const Status validation_status =
268  ValidateAspiredVersions(servable_name, versions);
269  DCHECK(validation_status.ok()) << validation_status.message();
270  if (!validation_status.ok()) {
271  LOG(ERROR) << validation_status.message();
272  return;
273  }
274 
275  {
276  mutex_lock l(pending_aspired_versions_requests_mu_);
277  VLOG(2) << "Enqueueing aspired versions request: " << servable_name << ": "
278  << ServableVersionsDebugString(versions);
279  pending_aspired_versions_requests_[string(servable_name)] =
280  std::move(versions);
281  }
282 }
283 
284 void AspiredVersionsManager::ProcessAspiredVersionsRequest(
285  const StringPiece servable_name,
286  std::vector<ServableData<std::unique_ptr<Loader>>> versions) {
287  VLOG(2) << "Processing aspired versions request: " << servable_name << ": "
288  << ServableVersionsDebugString(versions);
289 
290  const std::set<int64_t> next_aspired_versions = GetVersionNumbers(versions);
291 
292  // We gather all the servables with the servable_name and
293  // 1. Add the current aspired version numbers to a set,
294  // 2. Set the aspired bool to false for all current servable harnesses which
295  // are not aspired.
296  std::set<int64_t> current_aspired_versions;
297  std::set<int64_t> current_aspired_versions_with_error;
298  const std::vector<ServableStateSnapshot<Aspired>> state_snapshots =
299  basic_manager_->GetManagedServableStateSnapshots<Aspired>(
300  string(servable_name));
301  for (const ServableStateSnapshot<Aspired>& state_snapshot : state_snapshots) {
302  if (state_snapshot.additional_state->is_aspired) {
303  current_aspired_versions.insert(state_snapshot.id.version);
304  if (state_snapshot.state == LoaderHarness::State::kError) {
305  current_aspired_versions_with_error.insert(state_snapshot.id.version);
306  }
307  }
308  // If this version is not part of the aspired versions.
309  if (next_aspired_versions.find(state_snapshot.id.version) ==
310  next_aspired_versions.end()) {
311  VLOG(1) << "Setting is_aspired=false for " << state_snapshot.id;
312  basic_manager_->GetAdditionalServableState<Aspired>(state_snapshot.id)
313  ->is_aspired = false;
314  basic_manager_->CancelLoadServableRetry(state_snapshot.id);
315  }
316  }
317 
318  // We do a set_difference (A - B), on the next aspired versions and the
319  // current aspired versions to find the version numbers which need to be
320  // added the harness map.
321  std::set<int64_t> additions;
322  std::set_difference(
323  next_aspired_versions.begin(), next_aspired_versions.end(),
324  current_aspired_versions.begin(), current_aspired_versions.end(),
325  std::inserter(additions, additions.begin()));
326 
327  // We go through the aspired_servable_versions, pull out the versions which
328  // need to be added and add them to the harness map.
329  for (auto& version : versions) {
330  bool should_add = false;
331  const auto& version_id = version.id();
332  if (additions.find(version.id().version) != additions.end()) {
333  should_add = true;
334  }
335  if (enable_reload_servables_with_error_ &&
336  current_aspired_versions_with_error.find(version.id().version) !=
337  current_aspired_versions_with_error.end()) {
338  ServableId id;
339  id.name = std::string(servable_name);
340  id.version = version_id.version;
341  const Status manage_status = basic_manager_->StopManagingServable(id);
342  DCHECK(manage_status.ok()) << manage_status.message();
343  if (!manage_status.ok()) {
344  LOG(ERROR) << "Internal error: Unable to clear errored servable "
345  << version_id.DebugString()
346  << " from 'basic_manager_': " << manage_status.message();
347  }
348  should_add = true;
349  }
350 
351  // if this aspired version is not already present in the map.
352  if (should_add) {
353  const Status manage_status =
354  basic_manager_->ManageServableWithAdditionalState(
355  std::move(version), std::unique_ptr<Aspired>(new Aspired{true}));
356  DCHECK(manage_status.ok()) << manage_status.message();
357  if (!manage_status.ok()) {
358  LOG(ERROR) << "Internal error: Unable to transfer servable "
359  << version_id.DebugString()
360  << " to 'basic_manager_': " << manage_status.message();
361  }
362  }
363  }
364 }
365 
366 bool AspiredVersionsManager::ContainsAnyReaspiredVersions(
367  const StringPiece servable_name,
368  const std::vector<ServableData<std::unique_ptr<Loader>>>& versions) const {
369  const std::vector<ServableStateSnapshot<Aspired>> state_snapshots =
370  basic_manager_->GetManagedServableStateSnapshots<Aspired>(
371  string(servable_name));
372  const std::set<int64_t> version_numbers = GetVersionNumbers(versions);
373  for (const ServableStateSnapshot<Aspired>& state_snapshot : state_snapshots) {
374  if (!state_snapshot.additional_state->is_aspired &&
375  version_numbers.find(state_snapshot.id.version) !=
376  version_numbers.end()) {
377  return true;
378  }
379  }
380  return false;
381 }
382 
383 // We collect the version policy actions for each servable stream first. Then
384 // we sort them based on the global policy and pick the first one.
385 absl::optional<AspiredVersionPolicy::ServableAction>
386 AspiredVersionsManager::GetNextAction() {
387  std::vector<absl::optional<AspiredVersionPolicy::ServableAction>> actions;
388  for (const string& servable_name :
389  basic_manager_->GetManagedServableNames()) {
390  std::vector<AspiredServableStateSnapshot> aspired_state_snapshots;
391  for (const ServableStateSnapshot<Aspired>& state_snapshot :
392  basic_manager_->GetManagedServableStateSnapshots<Aspired>(
393  servable_name)) {
394  aspired_state_snapshots.push_back(
395  {state_snapshot.id, state_snapshot.state,
396  state_snapshot.additional_state->is_aspired});
397  }
398  actions.emplace_back(
399  aspired_version_policy_->GetNextAction(aspired_state_snapshots));
400  }
401 
402  std::sort(actions.begin(), actions.end(),
403  CompareActions(custom_sort_actions_));
404  const absl::optional<AspiredVersionPolicy::ServableAction> next_action =
405  !actions.empty() ? actions[0] : absl::nullopt;
406  if (next_action) {
407  VLOG(1) << "Taking action: " << next_action->DebugString();
408  }
409  return next_action;
410 }
411 
412 void AspiredVersionsManager::PerformAction(
413  const AspiredVersionPolicy::ServableAction action) {
414  switch (action.action) {
416  basic_manager_->LoadServable(action.id, [action](const Status& status) {
417  if (!status.ok()) {
418  LOG(ERROR) << "Servable " << action.id.DebugString()
419  << " cannot be loaded: " << status;
420  }
421  });
422  } break;
424  basic_manager_->UnloadServable(action.id, [action](const Status& status) {
425  if (!status.ok()) {
426  LOG(ERROR) << "Servable " << action.id.DebugString()
427  << " cannot be unloaded: " << status;
428  }
429  });
430  } break;
431  }
432 }
433 
434 void AspiredVersionsManager::FlushServables() {
435  mutex_lock l(basic_manager_read_modify_write_mu_);
436  for (const string& servable_name :
437  basic_manager_->GetManagedServableNames()) {
438  for (const ServableStateSnapshot<Aspired>& state_snapshot :
439  basic_manager_->GetManagedServableStateSnapshots<Aspired>(
440  servable_name)) {
441  if ((state_snapshot.state == LoaderHarness::State::kNew ||
442  state_snapshot.state == LoaderHarness::State::kDisabled ||
443  state_snapshot.state == LoaderHarness::State::kError) &&
444  !state_snapshot.additional_state->is_aspired) {
445  const Status status =
446  basic_manager_->StopManagingServable(state_snapshot.id);
447  if (status.ok()) {
448  VLOG(1) << "Removed " << state_snapshot.id << "from BasicManager";
449  } else {
450  // This scenario is likely a bug, perhaps a race (either here in
451  // AspiredVersionsManager, or in BasicManager). We'll wind up retrying
452  // StopManagingServable() on the next FlushServables() call, so just
453  // log the error and move on for now.
454  LOG(ERROR) << "Error removing " << state_snapshot.id
455  << "from BasicManager: " << status << " will retry later";
456  }
457  }
458  }
459  }
460 }
461 
462 void AspiredVersionsManager::HandlePendingAspiredVersionsRequests() {
463  mutex_lock l(basic_manager_read_modify_write_mu_);
464  mutex_lock l2(pending_aspired_versions_requests_mu_);
465 
466  // To be able to process an aspired-versions request, we wait for any
467  // re-aspired versions (versions not currently marked aspired, but present in
468  // the latest aspired-versions request) to quiesce and be removed from
469  // BasicManager. If an enqueued request does contain re-aspired versions, we
470  // simply leave it in the queue for now.
471  for (auto it = pending_aspired_versions_requests_.begin();
472  it != pending_aspired_versions_requests_.end();) {
473  const string& servable_name = it->first;
474  std::vector<ServableData<std::unique_ptr<Loader>>>& versions = it->second;
475 
476  if (ContainsAnyReaspiredVersions(servable_name, versions)) {
477  // Sit on it for now. We'll check again later.
478  ++it;
479  VLOG(1) << "Postponing processing of aspired versions request due to "
480  "re-aspired version(s) among: "
481  << ServableVersionsDebugString(versions);
482  } else {
483  ProcessAspiredVersionsRequest(servable_name, std::move(versions));
484  it = pending_aspired_versions_requests_.erase(it);
485  }
486  }
487 }
488 
489 void AspiredVersionsManager::InvokePolicyAndExecuteAction() {
490  mutex_lock l(basic_manager_read_modify_write_mu_);
491 
492  const absl::optional<AspiredVersionPolicy::ServableAction> next_action =
493  GetNextAction();
494  if (!next_action) {
495  return;
496  }
497  // NOTE: we could do action validation here.
498 
499  PerformAction(*next_action);
500 }
501 
502 void AspiredVersionsManager::SetNumLoadThreads(const uint32 num_load_threads) {
503  basic_manager_->SetNumLoadThreads(num_load_threads);
504 }
505 
506 uint32 AspiredVersionsManager::num_load_threads() const {
507  return basic_manager_->num_load_threads();
508 }
509 
510 } // namespace serving
511 } // namespace tensorflow
std::vector< ServableId > ListAvailableServableIds() const override
Source< std::unique_ptr< Loader > >::AspiredVersionsCallback GetAspiredVersionsCallback() override
Returns a callback to set the list of aspired versions for a particular servable stream,...