16 #include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
23 #include "google/protobuf/wrappers.pb.h"
24 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow_serving/batching/batching_session.h"
30 #include "tensorflow_serving/resources/resource_values.h"
31 #include "tensorflow_serving/servables/tensorflow/serving_session.h"
32 #include "tensorflow_serving/util/proto_util.h"
34 namespace tensorflow {
39 using Batcher = SharedBatchScheduler<BatchingSessionTask>;
41 const char kBatchingParamsFilename[] =
"batching_params.pbtxt";
43 bool BatchingParamsFound(
const string& model_dir) {
44 const string& fname = io::JoinPath(model_dir, kBatchingParamsFilename);
45 return Env::Default()->FilesExist({fname},
nullptr);
50 SessionOptions GetSessionOptions(
const SessionBundleConfig& config) {
51 SessionOptions options;
52 options.target = config.session_target();
53 options.config = config.session_config();
57 RunOptions GetRunOptions(
const SessionBundleConfig& config) {
58 RunOptions run_options;
59 if (config.has_session_run_load_threadpool_index()) {
60 run_options.set_inter_op_thread_pool(
61 config.session_run_load_threadpool_index().value());
66 Status GetPerModelBatchingParams(
const string& path,
67 const BatchingParameters& common_params,
68 bool per_model_configured,
69 absl::optional<BatchingParameters>* params) {
70 if (per_model_configured) {
71 if (BatchingParamsFound(path)) {
72 *params = absl::make_optional(BatchingParameters());
73 TF_RETURN_IF_ERROR(ParseProtoTextFile(
74 io::JoinPath(path, kBatchingParamsFilename), ¶ms->value()));
75 VLOG(1) <<
"Wrapping session to perform batch processing "
76 <<
"using SavedModel batching params: "
77 << params->value().DebugString();
80 *params = absl::make_optional(common_params);
81 VLOG(1) <<
"Wrapping session to perform batch processing "
82 <<
"using session config batching params: "
83 << params->value().DebugString();
85 return absl::OkStatus();
88 Status EstimateResourceFromValidationResult(
const string& path,
89 ResourceAllocation* estimate) {
90 return EstimateMainRamBytesFromValidationResult(path, estimate);
93 Status EstimateResourceFromPath(
const string& path,
bool use_validation_result,
94 ResourceAllocation* estimate) {
95 TensorflowFileProbingEnv env(Env::Default());
96 return EstimateMainRamBytesFromPath(path, use_validation_result, &env,
100 Status WrapSessionForBatching(
const BatchingParameters& batching_config,
101 std::shared_ptr<Batcher> batch_scheduler,
102 const std::vector<SignatureDef>& signatures,
103 std::unique_ptr<Session>* session) {
104 LOG(INFO) <<
"Wrapping session to perform batch processing";
106 if (batch_scheduler ==
nullptr) {
107 return errors::Internal(
"batch_scheduler not set");
109 if (*session ==
nullptr) {
110 return errors::Internal(
"session not set");
113 if (!batching_config.allowed_batch_sizes().empty()) {
115 const int last_allowed_size = batching_config.allowed_batch_sizes(
116 batching_config.allowed_batch_sizes().size() - 1);
117 const int max_size = batching_config.has_max_batch_size()
118 ? batching_config.max_batch_size().value()
119 : Batcher::QueueOptions().input_batch_size_limit;
120 if (last_allowed_size != max_size) {
121 return errors::InvalidArgument(
122 "Last entry in allowed_batch_sizes must match max_batch_size; last "
124 last_allowed_size,
"; expected ", max_size);
128 auto queue_options = GetQueueOptions<
131 [](std::unique_ptr<tensorflow::serving::BatchingSessionTask>* input_task,
132 int open_batch_remaining_slot,
int max_batch_size,
133 std::vector<std::unique_ptr<tensorflow::serving::BatchingSessionTask>>*
134 output_tasks) -> tensorflow::Status {
135 return SplitInputTask(input_task, open_batch_remaining_slot,
136 max_batch_size, output_tasks);
139 BatchingSessionOptions batching_session_options;
140 for (
int allowed_batch_size : batching_config.allowed_batch_sizes()) {
141 batching_session_options.allowed_batch_sizes.push_back(allowed_batch_size);
144 batching_session_options.pad_variable_length_inputs =
145 batching_config.pad_variable_length_inputs();
147 auto create_queue = [batch_scheduler, queue_options](
148 std::function<void(std::unique_ptr<Batch<BatchingSessionTask>>)>
149 process_batch_callback,
150 std::unique_ptr<BatchScheduler<BatchingSessionTask>>* queue) {
151 TF_RETURN_IF_ERROR(batch_scheduler->AddQueue(
152 queue_options, process_batch_callback, queue));
153 return absl::OkStatus();
155 std::vector<SignatureWithBatchingSessionSchedulerCreator>
156 signatures_with_scheduler_creators;
157 for (
const SignatureDef& signature : signatures) {
158 const TensorSignature tensor_signature =
159 TensorSignatureFromSignatureDef(signature);
160 signatures_with_scheduler_creators.push_back(
161 {tensor_signature, create_queue});
164 return CreateBatchingSession(batching_session_options,
165 signatures_with_scheduler_creators, create_queue,
166 std::move(*session), session);
169 Status WrapSession(std::unique_ptr<Session>* session) {
170 session->reset(
new ServingSessionWrapper(std::move(*session)));
171 return absl::OkStatus();
174 Status WrapSessionIgnoreThreadPoolOptions(std::unique_ptr<Session>* session) {
176 new SessionWrapperIgnoreThreadPoolOptions(std::move(*session)));
177 return absl::OkStatus();