TensorFlow Serving C++ API Documentation
bundle_factory_util.h
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_UTIL_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_UTIL_H_
18 
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "google/protobuf/wrappers.pb.h"
#include "absl/types/optional.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow_serving/batching/batching_session.h"
#include "tensorflow_serving/resources/resources.pb.h"
#include "tensorflow_serving/servables/tensorflow/resource_estimator.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/util/file_probing_env.h"
31 
32 namespace tensorflow {
33 namespace serving {
34 
35 // Returns SessionOptions based on the SessionBundleConfig.
36 // TODO(b/32248363): add SavedModelBundleConfig after we switch Model Server to
37 // Saved Model.
38 SessionOptions GetSessionOptions(const SessionBundleConfig& config);
39 
40 // Returns RunOptions based on SessionBundleConfig.
41 // TODO(b/32248363): add SavedModelBundleConfig after we switch Model Server to
42 // Saved Model.
43 RunOptions GetRunOptions(const SessionBundleConfig& config);
44 
45 // Get per-model batching parameters if they are present.
46 //
47 // When `per_model_configured` is true we return model specific batching
48 // parameters from `batching_params.pbtxt` file in SavedModel dir under `path`
49 // if one exists. If `per_model_configured` is false we return `common_params`.
50 // Failure to parse model specific params will return error.
51 Status GetPerModelBatchingParams(const string& path,
52  const BatchingParameters& common_params,
53  bool per_model_configured,
54  absl::optional<BatchingParameters>* params);
55 
56 // Creates a BatchScheduler based on the batching configuration.
57 template <typename TaskType>
58 Status CreateBatchScheduler(
59  const BatchingParameters& batching_config,
60  std::shared_ptr<SharedBatchScheduler<TaskType>>* batch_scheduler) {
61  typename SharedBatchScheduler<TaskType>::Options options;
62  if (batching_config.has_num_batch_threads()) {
63  options.num_batch_threads = batching_config.num_batch_threads().value();
64  }
65  if (batching_config.has_thread_pool_name()) {
66  options.thread_pool_name = batching_config.thread_pool_name().value();
67  }
68  return SharedBatchScheduler<TaskType>::Create(options, batch_scheduler);
69 }
70 
71 // Estimates the resources a session bundle or saved model bundle will use once
72 // loaded, from infra validation.
73 Status EstimateResourceFromValidationResult(const string& path,
74  ResourceAllocation* estimate);
75 
76 // Estimates the resources a session bundle or saved model bundle will use once
77 // loaded, from its export or saved model path. tensorflow::Env::Default() will
78 // be used to access the file system.
79 //
80 // If use_validation_result = true, tries to use the result from infra validtion
81 // first. Otherwise, uses the following crude heuristic: estimated main-memory
82 // RAM = (combined size of all exported file(s)) *
83 // kResourceEstimateRAMMultiplier + kResourceEstimateRAMPadBytes.
84 // TODO(b/27694447): Improve the heuristic. At a minimum, account for GPU RAM.
85 Status EstimateResourceFromPath(const string& path, bool use_validation_result,
86  ResourceAllocation* estimate);
87 
88 // Wraps a session in a new session that automatically batches Run() calls.
89 Status WrapSessionForBatching(
90  const BatchingParameters& batching_config,
91  std::shared_ptr<SharedBatchScheduler<BatchingSessionTask>> batch_scheduler,
92  const std::vector<SignatureDef>& signatures,
93  std::unique_ptr<Session>* session);
94 
95 // Wraps a session in a new session that only supports Run() without batching.
96 Status WrapSession(std::unique_ptr<Session>* session);
97 
98 // Wraps a session in a new session that only supports Run() without threading
99 // parameters.
100 Status WrapSessionIgnoreThreadPoolOptions(std::unique_ptr<Session>* session);
101 
102 // Construct Queue Options from BatchingParameters.
103 template <typename TaskType>
104 typename SharedBatchScheduler<TaskType>::QueueOptions GetQueueOptions(
105  const BatchingParameters& batching_config,
106  std::function<Status(std::unique_ptr<TaskType>* input_task,
107  int first_output_task_size, int input_batch_size_limit,
108  std::vector<std::unique_ptr<TaskType>>* output_tasks)>
109  split_input_task_func) {
110  typename SharedBatchScheduler<TaskType>::QueueOptions queue_options;
111  if (batching_config.has_max_batch_size()) {
112  queue_options.input_batch_size_limit =
113  batching_config.max_batch_size().value();
114  }
115  if (batching_config.has_batch_timeout_micros()) {
116  queue_options.batch_timeout_micros =
117  batching_config.batch_timeout_micros().value();
118  }
119  if (batching_config.has_max_enqueued_batches()) {
120  queue_options.max_enqueued_batches =
121  batching_config.max_enqueued_batches().value();
122  }
123  if (batching_config.has_enable_large_batch_splitting() &&
124  batching_config.enable_large_batch_splitting().value()) {
125  queue_options.enable_large_batch_splitting = true;
126 
127  if (batching_config.has_max_execution_batch_size()) {
128  queue_options.max_execution_batch_size =
129  batching_config.max_execution_batch_size().value();
130  } else {
131  queue_options.max_execution_batch_size =
132  batching_config.max_batch_size().value();
133  }
134 
135  queue_options.split_input_task_func = split_input_task_func;
136  }
137  return queue_options;
138 }
139 
140 } // namespace serving
141 } // namespace tensorflow
142 
143 #endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_UTIL_H_