// TensorFlow Serving C++ API — saved_model_bundle_factory.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_factory.h"
17 
18 #include <memory>
19 #include <unordered_set>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/cc/saved_model/tag_constants.h"
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/protobuf/config.pb.h"
29 #include "tensorflow/core/protobuf/meta_graph.pb.h"
30 #include "tensorflow/core/protobuf/named_tensor.pb.h"
31 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
32 #include "tensorflow/core/public/session_options.h"
33 #include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
34 #include "tensorflow_serving/servables/tensorflow/tflite_session.h"
35 #include "tensorflow_serving/session_bundle/session_bundle_util.h"
36 
37 namespace tensorflow {
38 namespace serving {
39 
40 namespace {
41 
42 // Extracts the signatures from 'bundle'.
43 std::vector<SignatureDef> GetSignatureDefs(const SavedModelBundle& bundle) {
44  std::vector<SignatureDef> signature_defs;
45  for (const auto& entry : bundle.meta_graph_def.signature_def()) {
46  const SignatureDef& signature_def = entry.second;
47  signature_defs.push_back(signature_def);
48  }
49  return signature_defs;
50 }
51 
52 // TODO(b/140959776): Move this upstream alongside `kSavedModelFilenamePb`.
53 const char kTfLiteModelFilename[] = "model.tflite";
54 
55 Status LoadTfLiteModel(const string& model_dir, SavedModelBundle* bundle,
56  const SessionOptions& options, int num_interpreter_pools,
57  int num_interpreters_per_pool) {
58  std::unique_ptr<TfLiteSession> session;
59 
60  const string& fname = io::JoinPath(model_dir, kTfLiteModelFilename);
61  uint64_t size;
62  TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(fname, &size));
63 
64  std::unique_ptr<RandomAccessFile> file;
65  TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(fname, &file));
66 
67  string model_bytes;
68  model_bytes.resize(size);
69  absl::string_view sv;
70  TF_RETURN_IF_ERROR(file->Read(0, size, &sv, &model_bytes[0]));
71 
72  std::unique_ptr<TfLiteSession> tflite_session;
73  TF_RETURN_IF_ERROR(TfLiteSession::Create(
74  std::move(model_bytes), options, num_interpreter_pools,
75  num_interpreters_per_pool, &tflite_session,
76  bundle->meta_graph_def.mutable_signature_def()));
77  bundle->session = std::move(tflite_session);
78  return absl::OkStatus();
79 }
80 
81 bool TfLiteModelFound(const string& model_dir) {
82  const string& fname = io::JoinPath(model_dir, kTfLiteModelFilename);
83  return Env::Default()->FilesExist({fname}, nullptr);
84 }
85 
86 } // namespace
87 
89  const SessionBundleConfig& config,
90  std::unique_ptr<SavedModelBundleFactory>* factory) {
91  std::shared_ptr<Batcher> batcher;
92  if (config.has_batching_parameters()) {
93  TF_RETURN_IF_ERROR(
94  CreateBatchScheduler(config.batching_parameters(), &batcher));
95  }
96  factory->reset(new SavedModelBundleFactory(config, batcher));
97  return absl::OkStatus();
98 }
99 
101  const string& path, ResourceAllocation* estimate) const {
102  return EstimateResourceFromPath(
103  path, config_.resource_estimation_uses_validation_result(), estimate);
104 }
105 
107  const Loader::Metadata& metadata, const string& path,
108  std::unique_ptr<SavedModelBundle>* bundle) {
109  return InternalCreateSavedModelBundle(metadata, path, bundle);
110 }
111 
113  const string& path, std::unique_ptr<SavedModelBundle>* bundle) {
114  return InternalCreateSavedModelBundle({}, path, bundle);
115 }
116 
// Shared implementation for the two public Create* entry points: loads the
// model at 'path' (either a TFLite flatbuffer or a SavedModel/SessionBundle)
// into '*bundle', then optionally strips unused MetaGraphDef fields and wraps
// the session per 'config_'. 'metadata', when present, is recorded in the
// session's experimental metadata.
Status SavedModelBundleFactory::InternalCreateSavedModelBundle(
    const absl::optional<Loader::Metadata>& metadata, const string& path,
    std::unique_ptr<SavedModelBundle>* bundle) {
  bundle->reset(new SavedModelBundle);
  std::unordered_set<string> saved_model_tags(
      config_.saved_model_tags().begin(), config_.saved_model_tags().end());
  // Defaults to loading the meta graph def corresponding to the `serve` tag
  // if no `saved_model_tags` are specified.
  if (saved_model_tags.empty()) {
    saved_model_tags.insert(kSavedModelTagServe);
  }
  // Build the SessionOptions once via an immediately-invoked lambda, layering
  // config-driven tweaks on top of GetSessionOptions(config_):
  //   (1) oneDNN bfloat16 auto mixed precision when requested, and
  //   (2) servable name/version recorded as session metadata.
  const auto& session_options = [&]() {
    auto result = GetSessionOptions(config_);
    string mixed_precision_value = config_.mixed_precision();
    if (!mixed_precision_value.empty()) {
      if (mixed_precision_value == "bfloat16") {
        LOG(INFO) << "Running inference with bfloat16 auto mixed precision";
        tensorflow::ConfigProto& config = result.config;
        GraphOptions* gopt = config.mutable_graph_options();
        RewriterConfig* rwcfg = gopt->mutable_rewrite_options();
        rwcfg->set_auto_mixed_precision_onednn_bfloat16(RewriterConfig::ON);
      } else {
        // Unsupported values are logged and ignored rather than failing the
        // load.
        LOG(WARNING)
            << config_.mixed_precision()
            << " auto mixed precision is not supported. Valid option: bfloat16";
      }
    }
    if (metadata.has_value()) {
      auto* session_metadata =
          result.config.mutable_experimental()->mutable_session_metadata();
      session_metadata->set_name(metadata->servable_id.name);
      session_metadata->set_version(metadata->servable_id.version);
    }
    return result;
  }();

  // Prefer the TFLite flatbuffer when configured and present on disk;
  // otherwise fall back to the regular SavedModel/SessionBundle loader.
  bool is_tflite = config_.prefer_tflite_model() && TfLiteModelFound(path);
  if (is_tflite) {
    int num_tflite_pools = config_.num_tflite_pools();
    // Legacy fallback: treat the older `num_tflite_interpreters` setting as
    // the pool count when `num_tflite_pools` is unset.
    if (num_tflite_pools == 0 && config_.num_tflite_interpreters() > 0) {
      num_tflite_pools = config_.num_tflite_interpreters();
    }
    TF_RETURN_IF_ERROR(LoadTfLiteModel(
        path, bundle->get(), session_options, num_tflite_pools,
        config_.num_tflite_interpreters_per_pool()));
  } else {
    TF_RETURN_IF_ERROR(session_bundle::LoadSessionBundleOrSavedModelBundle(
        session_options, GetRunOptions(config_), path, saved_model_tags,
        config_.enable_saved_model_config(), bundle->get()));
  }
  if (config_.remove_unused_fields_from_bundle_metagraph()) {
    // Save memory by removing fields in MetaGraphDef proto message stored
    // in the bundle that we never use. Notably the unused graphdef submessage
    // can get large (MBs) wasting memory on the server.
    //
    // Presently we retain following field(s) of MetaGraphDef proto:
    // - signature_def
    MetaGraphDef metagraph;
    (*bundle)->meta_graph_def.Swap(&metagraph);
    (*bundle)->meta_graph_def.mutable_signature_def()->swap(
        *metagraph.mutable_signature_def());
  }
  // Finally wrap the session. Precedence: ignore-thread-pool wrapping first,
  // then batching (when parameters resolve for this model), else plain wrap.
  if (config_.wrap_session_with_no_threading_params()) {
    return WrapSessionIgnoreThreadPoolOptions(&(*bundle)->session);
  } else if (config_.has_batching_parameters()) {
    absl::optional<BatchingParameters> batching_params;
    TF_RETURN_IF_ERROR(GetPerModelBatchingParams(
        path, config_.batching_parameters(),
        config_.enable_per_model_batching_params(), &batching_params));
    if (batching_params.has_value()) {
      // Enable batching of requests to any one signature_def in the SavedModel.
      // Note that in the future, the plan is to enable explicit configuration
      // of the one or many SignatureDefs to enable.
      const std::vector<SignatureDef> signatures = GetSignatureDefs(**bundle);
      return WrapSessionForBatching(batching_params.value(), batch_scheduler_,
                                    signatures, &(*bundle)->session);
    }
  }
  return WrapSession(&(*bundle)->session);
}
197 
198 SavedModelBundleFactory::SavedModelBundleFactory(
199  const SessionBundleConfig& config, std::shared_ptr<Batcher> batch_scheduler)
200  : config_(config), batch_scheduler_(batch_scheduler) {}
201 
202 } // namespace serving
203 } // namespace tensorflow
// Doxygen extraction residue — declarations from the class header, kept for
// reference:
//   static Status Create(const SessionBundleConfig& config,
//                        std::unique_ptr<SavedModelBundleFactory>* factory);
//   Status EstimateResourceRequirement(const string& path,
//                                      ResourceAllocation* estimate) const;
//   Status CreateSavedModelBundleWithMetadata(
//       const Loader::Metadata& metadata, const string& path,
//       std::unique_ptr<SavedModelBundle>* bundle);
//   Status CreateSavedModelBundle(const string& path,
//                                 std::unique_ptr<SavedModelBundle>* bundle);
//   (Loader::Metadata consists of the ServableId; see loader.h:94.)