#ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_UTIL_H_
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_UTIL_H_

#include "google/protobuf/wrappers.pb.h"
#include "absl/types/optional.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow_serving/batching/batching_session.h"
#include "tensorflow_serving/resources/resources.pb.h"
#include "tensorflow_serving/servables/tensorflow/resource_estimator.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/util/file_probing_env.h"
namespace tensorflow {
namespace serving {

// Returns SessionOptions based on the given SessionBundleConfig.
SessionOptions GetSessionOptions(const SessionBundleConfig& config);

// Returns RunOptions based on the given SessionBundleConfig.
RunOptions GetRunOptions(const SessionBundleConfig& config);
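
// Example use of the two accessors above (a sketch; `bundle_config` is a
// hypothetical SessionBundleConfig assembled by the caller):
//
//   SessionBundleConfig bundle_config;
//   SessionOptions session_options = GetSessionOptions(bundle_config);
//   RunOptions run_options = GetRunOptions(bundle_config);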

// Determines the effective batching parameters for a model, choosing between
// per-model parameters found under `path` (when `per_model_configured` is
// true) and the shared `common_params`. The result is returned in `params`.
Status GetPerModelBatchingParams(const string& path,
                                 const BatchingParameters& common_params,
                                 bool per_model_configured,
                                 absl::optional<BatchingParameters>* params);
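
// Example (a sketch; the model path and `common_params` are assumptions made
// by the caller):
//
//   absl::optional<BatchingParameters> params;
//   TF_RETURN_IF_ERROR(GetPerModelBatchingParams(
//       "/models/my_model/1", common_params, /*per_model_configured=*/true,
//       &params));
//   if (params) {
//     // Use *params when configuring the batch scheduler.
//   }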

// Creates a batch scheduler based on the batching configuration.
template <typename TaskType>
Status CreateBatchScheduler(
    const BatchingParameters& batching_config,
    std::shared_ptr<SharedBatchScheduler<TaskType>>* batch_scheduler) {
  typename SharedBatchScheduler<TaskType>::Options options;
  if (batching_config.has_num_batch_threads()) {
    options.num_batch_threads = batching_config.num_batch_threads().value();
  }
  if (batching_config.has_thread_pool_name()) {
    options.thread_pool_name = batching_config.thread_pool_name().value();
  }
  return SharedBatchScheduler<TaskType>::Create(options, batch_scheduler);
}
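
// Example (a sketch): creating a scheduler for BatchingSessionTask, the task
// type consumed by WrapSessionForBatching() below:
//
//   std::shared_ptr<SharedBatchScheduler<BatchingSessionTask>> scheduler;
//   TF_RETURN_IF_ERROR(CreateBatchScheduler<BatchingSessionTask>(
//       batching_config, &scheduler));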

// Estimates the resources a model will use once loaded, from the validation
// result stored with its export under `path`.
Status EstimateResourceFromValidationResult(const string& path,
                                            ResourceAllocation* estimate);

// Estimates the resources a session bundle or saved model bundle will use
// once loaded, from its export path on disk. If `use_validation_result` is
// true, the stored validation result is used for the estimate.
Status EstimateResourceFromPath(const string& path, bool use_validation_result,
                                ResourceAllocation* estimate);
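
// Example (a sketch; the export path is an assumption):
//
//   ResourceAllocation estimate;
//   TF_RETURN_IF_ERROR(EstimateResourceFromPath(
//       "/models/my_model/1", /*use_validation_result=*/false, &estimate));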

// Wraps `session` in a new session that batches compatible Run() calls
// through `batch_scheduler`, using `batching_config` for the queue
// configuration and `signatures` to group batchable calls.
Status WrapSessionForBatching(
    const BatchingParameters& batching_config,
    std::shared_ptr<SharedBatchScheduler<BatchingSessionTask>> batch_scheduler,
    const std::vector<SignatureDef>& signatures,
    std::unique_ptr<Session>* session);
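
// Example (a sketch; `batching_config`, `scheduler`, `signatures`, and
// `session` are assumed to be set up as in the examples above):
//
//   TF_RETURN_IF_ERROR(WrapSessionForBatching(batching_config, scheduler,
//                                             signatures, &session));
//   // `session` now transparently batches compatible Run() calls.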

// Wraps `session` in a new session that only supports Run().
Status WrapSession(std::unique_ptr<Session>* session);

// Like WrapSession(), but the wrapped session ignores any thread pool
// options supplied at Run() time.
Status WrapSessionIgnoreThreadPoolOptions(std::unique_ptr<Session>* session);
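
// Example use of GetQueueOptions(), declared below: building queue options
// with a pass-through split function (a sketch; a real splitter would divide
// `input_task` into pieces no larger than `input_batch_size_limit`):
//
//   auto queue_options = GetQueueOptions<BatchingSessionTask>(
//       batching_config,
//       [](std::unique_ptr<BatchingSessionTask>* input_task,
//          int first_output_task_size, int input_batch_size_limit,
//          std::vector<std::unique_ptr<BatchingSessionTask>>* output_tasks) {
//         output_tasks->push_back(std::move(*input_task));
//         return Status::OK();
//       });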

// Builds SharedBatchScheduler queue options from `batching_config`. When
// large batch splitting is enabled in the config, `split_input_task_func`
// and the execution batch size are also wired into the returned options.
template <typename TaskType>
typename SharedBatchScheduler<TaskType>::QueueOptions GetQueueOptions(
    const BatchingParameters& batching_config,
    std::function<Status(std::unique_ptr<TaskType>* input_task,
                         int first_output_task_size,
                         int input_batch_size_limit,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func) {
  typename SharedBatchScheduler<TaskType>::QueueOptions queue_options;
  if (batching_config.has_max_batch_size()) {
    queue_options.input_batch_size_limit =
        batching_config.max_batch_size().value();
  }
  if (batching_config.has_batch_timeout_micros()) {
    queue_options.batch_timeout_micros =
        batching_config.batch_timeout_micros().value();
  }
  if (batching_config.has_max_enqueued_batches()) {
    queue_options.max_enqueued_batches =
        batching_config.max_enqueued_batches().value();
  }
  if (batching_config.has_enable_large_batch_splitting() &&
      batching_config.enable_large_batch_splitting().value()) {
    queue_options.enable_large_batch_splitting = true;
    if (batching_config.has_max_execution_batch_size()) {
      queue_options.max_execution_batch_size =
          batching_config.max_execution_batch_size().value();
    } else {
      queue_options.max_execution_batch_size =
          batching_config.max_batch_size().value();
    }
    queue_options.split_input_task_func = split_input_task_func;
  }
  return queue_options;
}

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_UTIL_H_