TensorFlow Serving C++ API Documentation
saved_model_bundle_source_adapter.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/types.h"
24 #include "tensorflow_serving/resources/resource_util.h"
25 #include "tensorflow_serving/resources/resource_values.h"
26 #include "tensorflow_serving/resources/resources.pb.h"
27 #include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
28 #include "tensorflow_serving/servables/tensorflow/file_acl.h"
29 #include "tensorflow_serving/servables/tensorflow/machine_learning_metadata.h"
30 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_factory.h"
31 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup.h"
32 
33 namespace tensorflow {
34 namespace serving {
35 
36 Status SavedModelBundleSourceAdapter::Create(
37  const SavedModelBundleSourceAdapterConfig& config,
38  std::unique_ptr<SavedModelBundleSourceAdapter>* adapter) {
39  std::unique_ptr<SavedModelBundleFactory> bundle_factory;
40  TF_RETURN_IF_ERROR(
41  SavedModelBundleFactory::Create(config.legacy_config(), &bundle_factory));
42  adapter->reset(new SavedModelBundleSourceAdapter(std::move(bundle_factory)));
43  return absl::OkStatus();
44 }
45 
46 SavedModelBundleSourceAdapter::~SavedModelBundleSourceAdapter() { Detach(); }
47 
48 SavedModelBundleSourceAdapter::SavedModelBundleSourceAdapter(
49  std::unique_ptr<SavedModelBundleFactory> bundle_factory)
50  : bundle_factory_(std::move(bundle_factory)) {}
51 
52 SimpleLoader<SavedModelBundle>::CreatorVariant
53 SavedModelBundleSourceAdapter::GetServableCreator(
54  std::shared_ptr<SavedModelBundleFactory> bundle_factory,
55  const StoragePath& path) const {
56  if (bundle_factory->config().enable_session_metadata()) {
57  return [bundle_factory, path](const Loader::Metadata& metadata,
58  std::unique_ptr<SavedModelBundle>* bundle) {
59  TF_RETURN_IF_ERROR(RegisterModelRoot(metadata.servable_id, path));
60  TF_RETURN_IF_ERROR(bundle_factory->CreateSavedModelBundleWithMetadata(
61  metadata, path, bundle));
62  MaybePublishMLMDStreamz(path, metadata.servable_id.name,
63  metadata.servable_id.version);
64  if (bundle_factory->config().enable_model_warmup()) {
65  bundle_factory->mutable_config()
66  .mutable_model_warmup_options()
67  ->set_model_name(metadata.servable_id.name);
68  bundle_factory->mutable_config()
69  .mutable_model_warmup_options()
70  ->set_model_version(metadata.servable_id.version);
71  return RunSavedModelWarmup(
72  bundle_factory->config().model_warmup_options(),
73  GetRunOptions(bundle_factory->config()), path, bundle->get());
74  }
75  return absl::OkStatus();
76  };
77  }
78  return [bundle_factory, path](std::unique_ptr<SavedModelBundle>* bundle) {
79  TF_RETURN_IF_ERROR(bundle_factory->CreateSavedModelBundle(path, bundle));
80  if (bundle_factory->config().enable_model_warmup()) {
81  return RunSavedModelWarmup(
82  bundle_factory->config().model_warmup_options(),
83  GetRunOptions(bundle_factory->config()), path, bundle->get());
84  }
85  return absl::OkStatus();
86  };
87 }
88 
89 Status SavedModelBundleSourceAdapter::Convert(const StoragePath& path,
90  std::unique_ptr<Loader>* loader) {
91  std::shared_ptr<SavedModelBundleFactory> bundle_factory = bundle_factory_;
92  auto servable_creator = GetServableCreator(bundle_factory, path);
93  auto resource_estimator = [bundle_factory,
94  path](ResourceAllocation* estimate) {
95  TF_RETURN_IF_ERROR(
96  bundle_factory->EstimateResourceRequirement(path, estimate));
97 
98  // Add experimental_transient_ram_bytes_during_load.
99  // TODO(b/38376838): Remove once resource estimates are moved inside
100  // SavedModel.
101  ResourceUtil::Options resource_util_options;
102  resource_util_options.devices = {{device_types::kMain, 1}};
103  std::unique_ptr<ResourceUtil> resource_util =
104  std::unique_ptr<ResourceUtil>(new ResourceUtil(resource_util_options));
105  const Resource ram_resource = resource_util->CreateBoundResource(
106  device_types::kMain, resource_kinds::kRamBytes);
107  resource_util->SetQuantity(
108  ram_resource,
109  resource_util->GetQuantity(ram_resource, *estimate) +
110  bundle_factory->config()
111  .experimental_transient_ram_bytes_during_load(),
112  estimate);
113 
114  return absl::OkStatus();
115  };
116  auto post_load_resource_estimator = [bundle_factory,
117  path](ResourceAllocation* estimate) {
118  return bundle_factory->EstimateResourceRequirement(path, estimate);
119  };
120  loader->reset(new SimpleLoader<SavedModelBundle>(
121  servable_creator, resource_estimator, {post_load_resource_estimator}));
122  return absl::OkStatus();
123 }
124 
125 // Register the source adapter.
127  public:
128  static Status Create(
129  const SavedModelBundleSourceAdapterConfig& config,
130  std::unique_ptr<SourceAdapter<StoragePath, std::unique_ptr<Loader>>>*
131  adapter) {
132  std::unique_ptr<SavedModelBundleFactory> bundle_factory;
133  TF_RETURN_IF_ERROR(SavedModelBundleFactory::Create(config.legacy_config(),
134  &bundle_factory));
135  adapter->reset(
136  new SavedModelBundleSourceAdapter(std::move(bundle_factory)));
137  return absl::OkStatus();
138  }
139 };
140 REGISTER_STORAGE_PATH_SOURCE_ADAPTER(SavedModelBundleSourceAdapterCreator,
141  SavedModelBundleSourceAdapterConfig);
142 } // namespace serving
143 } // namespace tensorflow
static Status Create(const SessionBundleConfig &config, std::unique_ptr< SavedModelBundleFactory > *factory)