16 #include "tensorflow_serving/sources/storage_path/file_system_storage_path_source.h"
24 #include <unordered_set>
28 #include "absl/status/status.h"
29 #include "absl/strings/str_cat.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/io/path.h"
32 #include "tensorflow/core/lib/strings/numbers.h"
33 #include "tensorflow/core/platform/env.h"
34 #include "tsl/platform/errors.h"
35 #include "tsl/platform/macros.h"
36 #include "tensorflow_serving/core/servable_data.h"
37 #include "tensorflow_serving/core/servable_id.h"
39 namespace tensorflow {
// Destructor. Resets fs_polling_thread_, destroying whatever thread object it
// currently holds.
// NOTE(review): presumably destroying that object blocks until the polling
// thread function has finished -- confirm against ThreadType's destructor.
42 FileSystemStoragePathSource::~FileSystemStoragePathSource() {
46 fs_polling_thread_.reset();
53 std::set<string> GetDeletedServables(
54 const FileSystemStoragePathSourceConfig& old_config,
55 const FileSystemStoragePathSourceConfig& new_config) {
56 std::set<string> new_servables;
57 for (
const FileSystemStoragePathSourceConfig::ServableToMonitor& servable :
58 new_config.servables()) {
59 new_servables.insert(servable.servable_name());
62 std::set<string> deleted_servables;
63 for (
const FileSystemStoragePathSourceConfig::ServableToMonitor&
64 old_servable : old_config.servables()) {
65 if (new_servables.find(old_servable.servable_name()) ==
66 new_servables.end()) {
67 deleted_servables.insert(old_servable.servable_name());
70 return deleted_servables;
// Appends one aspired-version entry to 'versions'.
// NOTE(review): the line carrying this function's return type and name is not
// part of this listing; the four-argument AspireVersion(...) calls elsewhere
// in the file appear to target it -- confirm against the full file.
//
//   servable:              supplies the servable name and the base path.
//   version_relative_path: joined onto servable.base_path() to form the path.
//   version_number:        paired with the servable name to form the id.
//   versions:              output vector; exactly one element is appended.
76 const FileSystemStoragePathSourceConfig::ServableToMonitor& servable,
77 const string& version_relative_path,
const int64_t version_number,
78 std::vector<ServableData<StoragePath>>* versions) {
// The id is the (servable name, version number) pair.
79 const ServableId servable_id = {servable.servable_name(), version_number};
// The storage path is the base path with the relative version path appended.
80 const string full_path =
81 io::JoinPath(servable.base_path(), version_relative_path);
82 versions->emplace_back(ServableData<StoragePath>(servable_id, full_path));
87 bool ParseVersionNumber(
const string& version_path, int64_t* version_number) {
88 return strings::safe_strto64(version_path.c_str(), version_number);
97 bool AspireAllVersions(
98 const FileSystemStoragePathSourceConfig::ServableToMonitor& servable,
99 const std::vector<string>& children,
100 std::vector<ServableData<StoragePath>>* versions) {
101 bool at_least_one_version_found =
false;
102 for (
const string& child : children) {
105 int64_t version_number;
106 if (ParseVersionNumber(child, &version_number)) {
108 AspireVersion(servable, child, version_number, versions);
109 at_least_one_version_found =
true;
113 return at_least_one_version_found;
// Builds an ordered map from version number to the child name that parsed to
// that number. When two children parse to the same number a warning is logged
// and the later child replaces the earlier one.
//
//   children: candidate names, each passed to ParseVersionNumber().
//   returns:  map keyed by version number; the value is the child name.
119 std::map<int64_t ,
string >
120 IndexChildrenByVersion(
const std::vector<string>& children) {
121 std::map<int64_t, string> children_by_version;
122 for (
int i = 0; i < children.size(); ++i) {
123 int64_t version_number;
// NOTE(review): no statement is shown for the parse-failure branch in this
// listing; presumably children that are not numeric are skipped -- confirm
// against the full file.
124 if (!ParseVersionNumber(children[i], &version_number)) {
// A child was already recorded for this version: warn before it is replaced.
128 if (children_by_version.count(version_number) > 0) {
129 LOG(WARNING) <<
"Duplicate version directories detected. Version "
130 << version_number <<
" will be loaded from " << children[i]
131 <<
", " << children_by_version[version_number]
132 <<
" will be ignored.";
// Last writer wins for a given version number.
134 children_by_version[version_number] = children[i];
136 return children_by_version;
// Aspires versions for a servable using the "latest" part of its version
// policy, visiting 'children_by_version' from the largest version number
// downward.
//
//   servable:            supplies the policy, servable name and base path.
//   children_by_version: version number -> child name.
//   versions:            output vector that AspireVersion() appends to.
//   returns:             true iff 'children_by_version' is non-empty.
144 bool AspireLatestVersions(
145 const FileSystemStoragePathSourceConfig::ServableToMonitor& servable,
146 const std::map<int64_t, string>& children_by_version,
147 std::vector<ServableData<StoragePath>>* versions) {
// The configured count is clamped up to 1, so at least one version is wanted.
148 const int32 num_servable_versions_to_serve =
149 std::max(servable.servable_version_policy().latest().num_versions(), 1U);
152 int num_versions_emitted = 0;
// Reverse iteration over the ordered map yields the highest versions first.
153 for (
auto rit = children_by_version.rbegin();
154 rit != children_by_version.rend(); ++rit) {
// NOTE(review): no statement is shown for this limit-reached branch in this
// listing; presumably the loop stops here -- confirm against the full file.
155 if (num_versions_emitted == num_servable_versions_to_serve) {
158 const int64_t version = rit->first;
159 const string& child = rit->second;
160 AspireVersion(servable, child, version, versions);
161 num_versions_emitted++;
164 return !children_by_version.empty();
// Handles the "specific" version policy by probing for each requested version
// directly, at base_path joined with the version number as text.
// NOTE(review): no return statements are visible in this listing. The caller
// returns OK without further polling when this function returns true, so true
// presumably means "servable fully handled" -- confirm against the full file.
//
//   servable: supplies the requested versions, servable name and base path.
//   versions: output vector that AspireVersion() appends to.
174 bool AspireSpecificVersionsFastPath(
175 const FileSystemStoragePathSourceConfig::ServableToMonitor& servable,
176 std::vector<ServableData<StoragePath>>* versions) {
// An empty request list is logged as a warning.
177 if (servable.servable_version_policy().specific().versions().empty()) {
181 LOG(WARNING) <<
"No specific versions requested for servable "
182 << servable.servable_name() <<
".";
// First pass: check that a path named after each requested version exists.
189 for (
const int64_t version :
190 servable.servable_version_policy().specific().versions()) {
191 const string version_dir = absl::StrCat(version);
192 const string child_dir = io::JoinPath(servable.base_path(), version_dir);
// NOTE(review): how 'status' is acted on is not shown; presumably a failed
// existence check abandons the fast path -- confirm.
194 const absl::Status status = Env::Default()->FileExists(child_dir);
// Second pass: aspire every requested version under the same naming scheme.
201 for (
const int64_t version :
202 servable.servable_version_policy().specific().versions()) {
203 const string version_dir = absl::StrCat(version);
204 AspireVersion(servable, version_dir, version, versions);
// Aspires versions for a servable using the "specific" part of its version
// policy, choosing among the versions present in 'children_by_version'.
//
//   servable:            supplies the requested versions, name and base path.
//   children_by_version: version number -> child name.
//   versions:            output vector that AspireVersion() appends to.
//   returns:             true iff at least one version was aspired.
215 bool AspireSpecificVersions(
216 const FileSystemStoragePathSourceConfig::ServableToMonitor& servable,
217 const std::map<int64_t, string>& children_by_version,
218 std::vector<ServableData<StoragePath>>* versions) {
// Copy the requested versions into a hash set for membership tests.
219 const std::unordered_set<int64_t> versions_to_serve(
220 servable.servable_version_policy().specific().versions().begin(),
221 servable.servable_version_policy().specific().versions().end());
// Tracks what was actually aspired so missing requests can be reported below.
225 std::unordered_set<int64_t> aspired_versions;
// NOTE(review): the end of this loop header (increment and opening brace) is
// not part of this listing.
226 for (
auto it = children_by_version.begin(); it != children_by_version.end();
228 const int64_t version = it->first;
// NOTE(review): no statement is shown for this branch; presumably versions
// that were not requested are skipped -- confirm against the full file.
229 if (versions_to_serve.count(version) == 0) {
232 const string& child = it->second;
233 AspireVersion(servable, child, version, versions);
234 aspired_versions.insert(version);
// Report every requested version that was not aspired above.
236 for (
const int64_t version : versions_to_serve) {
// NOTE(review): the start of the logging statement (stream and severity) is
// not part of this listing.
237 if (aspired_versions.count(version) == 0) {
239 <<
"Version " << version <<
" of servable "
240 << servable.servable_name() <<
", which was requested to be served "
241 <<
"as a 'specific' version in the servable's version policy, was "
242 <<
"not found in the file system";
246 return !aspired_versions.empty();
// Polls the file system for one servable and appends the versions to aspire
// to 'versions', dispatching on the servable's version policy.
// NOTE(review): several lines of this function (guards, breaks and the tail)
// are not part of this listing; the notes below mark the visible gaps.
//
//   servable: supplies the base path, servable name and version policy.
//   versions: output vector filled through the Aspire*() helpers.
250 Status PollFileSystemForServable(
251 const FileSystemStoragePathSourceConfig::ServableToMonitor& servable,
252 std::vector<ServableData<StoragePath>>* versions) {
// Check the base path first so a missing path can be reported with context.
257 Status status = Env::Default()->FileExists(servable.base_path());
// NOTE(review): the condition guarding this return is not shown; presumably
// it is taken only when 'status' is not OK -- confirm.
259 return errors::InvalidArgument(
260 "Could not find base path ", servable.base_path(),
" for servable ",
261 servable.servable_name(),
" with error ", status.ToString());
// For the "specific" policy, try the direct-probe path before listing the
// directory; a true result ends the poll for this servable.
264 if (servable.servable_version_policy().policy_choice_case() ==
265 FileSystemStoragePathSourceConfig::ServableVersionPolicy::kSpecific) {
268 if (AspireSpecificVersionsFastPath(servable, versions)) {
270 return absl::OkStatus();
// List the entries under the base path.
// NOTE(review): the opening of the expression wrapping the GetChildren() call
// is not part of this listing.
275 std::vector<string> children;
277 Env::Default()->GetChildren(servable.base_path(), &children));
// Reduce each entry to the text before its first '/', de-duplicated by the
// set.
// NOTE(review): presumably some file systems return nested descendants from
// GetChildren() -- confirm.
281 std::set<string> real_children;
282 for (
int i = 0; i < children.size(); ++i) {
283 const string& child = children[i];
284 real_children.insert(child.substr(0, child.find_first_of(
'/')));
// NOTE(review): as shown, the de-duplicated names are inserted in front of
// the existing entries. If a children.clear() preceding this line was dropped
// from the listing, the vector holds only the de-duplicated names -- confirm.
287 children.insert(children.begin(), real_children.begin(), real_children.end());
288 const std::map<int64_t ,
string >
289 children_by_version = IndexChildrenByVersion(children);
291 bool at_least_one_version_found =
false;
// An unset policy falls through to the "latest" case.
// NOTE(review): no break statements or default label are visible between the
// cases in this listing -- confirm against the full file.
292 switch (servable.servable_version_policy().policy_choice_case()) {
293 case FileSystemStoragePathSourceConfig::ServableVersionPolicy::
294 POLICY_CHOICE_NOT_SET:
295 TF_FALLTHROUGH_INTENDED;
296 case FileSystemStoragePathSourceConfig::ServableVersionPolicy::kLatest:
297 at_least_one_version_found =
298 AspireLatestVersions(servable, children_by_version, versions);
300 case FileSystemStoragePathSourceConfig::ServableVersionPolicy::kAll:
301 at_least_one_version_found =
302 AspireAllVersions(servable, children, versions);
304 case FileSystemStoragePathSourceConfig::ServableVersionPolicy::kSpecific:
305 at_least_one_version_found =
306 AspireSpecificVersions(servable, children_by_version, versions);
309 return errors::Internal(
"Unhandled servable version_policy: ",
310 servable.servable_version_policy().DebugString());
// Finding no version is logged as a warning rather than returned as an error
// here.
// NOTE(review): the remainder of this message and of the function is not part
// of this listing.
313 if (!at_least_one_version_found) {
314 LOG(WARNING) <<
"No versions of servable " << servable.servable_name()
315 <<
" found under base path " << servable.base_path()
316 <<
". Did you forget to name your leaf directory as a number "
// Polls the file system for every servable listed in 'config' and records the
// resulting versions per servable name. The first polling error is returned
// immediately through TF_RETURN_IF_ERROR.
// NOTE(review): the final return statement is not part of this listing.
//
//   config:                    lists the servables to poll.
//   versions_by_servable_name: output map, servable name -> aspired versions.
326 Status PollFileSystemForConfig(
327 const FileSystemStoragePathSourceConfig& config,
328 std::map<
string, std::vector<ServableData<StoragePath>>>*
329 versions_by_servable_name) {
330 for (
const FileSystemStoragePathSourceConfig::ServableToMonitor& servable :
331 config.servables()) {
332 std::vector<ServableData<StoragePath>> versions;
333 TF_RETURN_IF_ERROR(PollFileSystemForServable(servable, &versions));
// map::insert keeps an existing entry, so if two config entries share a
// servable name the first one's versions are the ones recorded.
334 versions_by_servable_name->insert(
335 {servable.servable_name(), std::move(versions)});
// Polls the file system for 'config' and returns a NotFound error naming the
// first servable for which no version was produced.
// NOTE(review): the head of the expression wrapping the
// PollFileSystemForConfig() call and the final return statement are not part
// of this listing.
342 Status FailIfZeroVersions(
const FileSystemStoragePathSourceConfig& config) {
343 std::map<string, std::vector<ServableData<StoragePath>>>
344 versions_by_servable_name;
346 PollFileSystemForConfig(config, &versions_by_servable_name));
// Remember each servable's base path so the error message can include it.
348 std::map<string, string> servable_name_to_base_path_map;
349 for (
const FileSystemStoragePathSourceConfig::ServableToMonitor& servable :
350 config.servables()) {
351 servable_name_to_base_path_map.insert(
352 {servable.servable_name(), servable.base_path()});
// Entries are visited in servable-name order because the map is ordered.
355 for (
const auto& entry : versions_by_servable_name) {
356 const string& servable = entry.first;
357 const std::vector<ServableData<StoragePath>>& versions = entry.second;
358 if (versions.empty()) {
359 return errors::NotFound(
360 "Unable to find a numerical version path for servable ", servable,
361 " at: ", servable_name_to_base_path_map[servable]);
369 Status FileSystemStoragePathSource::Create(
370 const FileSystemStoragePathSourceConfig& config,
371 std::unique_ptr<FileSystemStoragePathSource>* result) {
372 result->reset(
new FileSystemStoragePathSource());
373 return (*result)->UpdateConfig(config);
// Validates 'config' against the current state before it is applied.
// NOTE(review): the line carrying this function's return type and name, and
// everything after the last visible statement, are not part of this listing.
// Create() calls UpdateConfig(config) and the body uses member state, so this
// is presumably FileSystemStoragePathSource::UpdateConfig -- confirm.
377 const FileSystemStoragePathSourceConfig& config) {
// Once a polling thread object exists, the poll interval may not change.
380 if (fs_polling_thread_ !=
nullptr &&
381 config.file_system_poll_wait_seconds() !=
382 config_.file_system_poll_wait_seconds()) {
383 return errors::InvalidArgument(
384 "Changing file_system_poll_wait_seconds is not supported");
// Either flag requires that every servable in the new config has a version.
387 if (config.fail_if_zero_versions_at_startup() ||
388 config.servable_versions_always_present()) {
389 TF_RETURN_IF_ERROR(FailIfZeroVersions(config));
// With a callback set, un-aspire servables that the new config no longer
// lists.
392 if (aspired_versions_callback_) {
393 TF_RETURN_IF_ERROR(UnaspireServables(GetDeletedServables(config_, config)));
// Stores the aspired-versions callback and creates the thread object that
// polls the file system, choosing its kind from
// config_.file_system_poll_wait_seconds().
// NOTE(review): several statements of this function are cut off in this
// listing; the notes below mark the visible gaps.
400 void FileSystemStoragePathSource::SetAspiredVersionsCallback(
401 AspiredVersionsCallback callback) {
// An existing thread object is reported as a repeated call.
// NOTE(review): the rest of this branch is not shown; going by the message the
// call presumably returns without changing anything -- confirm.
404 if (fs_polling_thread_ !=
nullptr) {
405 LOG(ERROR) <<
"SetAspiredVersionsCallback() called multiple times; "
406 "ignoring this call";
410 aspired_versions_callback_ = callback;
// The polling closure captures 'this' and calls back into this object.
412 const auto thread_fn = [
this](void) {
413 Status status = this->PollFileSystemAndInvokeCallback();
// NOTE(review): the condition guarding this log and the end of the message
// are not shown; presumably it is emitted only when 'status' is not OK.
415 LOG(ERROR) <<
"FileSystemStoragePathSource encountered a "
416 "filesystem access error: "
// A wait of exactly 0 seconds selects a thread started through
// Env::StartThread.
// NOTE(review): the remaining StartThread arguments are not shown; the thread
// name suggests a single one-shot poll -- confirm.
421 if (config_.file_system_poll_wait_seconds() == 0) {
423 fs_polling_thread_.reset(
new FileSystemStoragePathSource::ThreadType(
424 absl::in_place_type_t<std::unique_ptr<Thread>>(),
425 Env::Default()->StartThread(
427 "FileSystemStoragePathSource_filesystem_oneshot_thread",
429 }
else if (config_.file_system_poll_wait_seconds() > 0) {
// A positive wait selects a PeriodicFunction running thread_fn. The interval
// is the configured seconds multiplied by 1000000.
// NOTE(review): presumably PeriodicFunction expects microseconds -- confirm.
431 PeriodicFunction::Options pf_options;
432 pf_options.thread_name_prefix =
433 "FileSystemStoragePathSource_filesystem_polling_thread";
434 fs_polling_thread_.reset(
new FileSystemStoragePathSource::ThreadType(
435 absl::in_place_type_t<PeriodicFunction>(), thread_fn,
436 config_.file_system_poll_wait_seconds() * 1000000, pf_options));
// Polls the file system for every servable in config_ and passes each
// servable's versions to CallAspiredVersionsCallback().
// NOTE(review): the head of the expression wrapping the
// PollFileSystemForConfig() call, parts of the branch bodies and the final
// return are not part of this listing.
440 Status FileSystemStoragePathSource::PollFileSystemAndInvokeCallback() {
442 std::map<string, std::vector<ServableData<StoragePath>>>
443 versions_by_servable_name;
445 PollFileSystemForConfig(config_, &versions_by_servable_name));
446 for (
const auto& entry : versions_by_servable_name) {
447 const string& servable = entry.first;
448 const std::vector<ServableData<StoragePath>>& versions = entry.second;
// With servable_versions_always_present set, an empty result is logged as an
// error.
// NOTE(review): the rest of this branch is not shown; going by the message the
// callback is presumably skipped for this servable -- confirm.
449 if (versions.empty() && config_.servable_versions_always_present()) {
450 LOG(ERROR) <<
"Refusing to unload all versions for Servable: "
// Verbose trace of each version whose status is OK. DataOrDie() is only
// reached after that status check.
454 for (
const ServableData<StoragePath>& version : versions) {
455 if (version.status().ok()) {
456 VLOG(1) <<
"File-system polling update: Servable:" << version.id()
457 <<
"; Servable path: " << version.DataOrDie()
458 <<
"; Polling frequency: "
459 << config_.file_system_poll_wait_seconds();
462 CallAspiredVersionsCallback(servable, versions);
// Invokes CallAspiredVersionsCallback() with an empty version list for every
// name in 'servable_names'.
// NOTE(review): the final return statement is not part of this listing.
467 Status FileSystemStoragePathSource::UnaspireServables(
468 const std::set<string>& servable_names) {
469 for (
const string& servable_name : servable_names) {
// The empty vector means that no version of this servable is aspired.
470 CallAspiredVersionsCallback(servable_name,
471 std::vector<ServableData<StoragePath>>{});
Status UpdateConfig(const FileSystemStoragePathSourceConfig &config)