TensorFlow Serving C++ API Documentation
Source file: bundle_factory_util.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "google/protobuf/wrappers.pb.h"
24 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow_serving/batching/batching_session.h"
30 #include "tensorflow_serving/resources/resource_values.h"
31 #include "tensorflow_serving/servables/tensorflow/serving_session.h"
32 #include "tensorflow_serving/util/proto_util.h"
33 
34 namespace tensorflow {
35 namespace serving {
36 
37 namespace {
38 
39 using Batcher = SharedBatchScheduler<BatchingSessionTask>;
40 
41 const char kBatchingParamsFilename[] = "batching_params.pbtxt";
42 
43 bool BatchingParamsFound(const string& model_dir) {
44  const string& fname = io::JoinPath(model_dir, kBatchingParamsFilename);
45  return Env::Default()->FilesExist({fname}, nullptr);
46 }
47 
48 } // namespace
49 
50 SessionOptions GetSessionOptions(const SessionBundleConfig& config) {
51  SessionOptions options;
52  options.target = config.session_target();
53  options.config = config.session_config();
54  return options;
55 }
56 
57 RunOptions GetRunOptions(const SessionBundleConfig& config) {
58  RunOptions run_options;
59  if (config.has_session_run_load_threadpool_index()) {
60  run_options.set_inter_op_thread_pool(
61  config.session_run_load_threadpool_index().value());
62  }
63  return run_options;
64 }
65 
66 Status GetPerModelBatchingParams(const string& path,
67  const BatchingParameters& common_params,
68  bool per_model_configured,
69  absl::optional<BatchingParameters>* params) {
70  if (per_model_configured) {
71  if (BatchingParamsFound(path)) {
72  *params = absl::make_optional(BatchingParameters());
73  TF_RETURN_IF_ERROR(ParseProtoTextFile(
74  io::JoinPath(path, kBatchingParamsFilename), &params->value()));
75  VLOG(1) << "Wrapping session to perform batch processing "
76  << "using SavedModel batching params: "
77  << params->value().DebugString();
78  }
79  } else {
80  *params = absl::make_optional(common_params);
81  VLOG(1) << "Wrapping session to perform batch processing "
82  << "using session config batching params: "
83  << params->value().DebugString();
84  }
85  return absl::OkStatus();
86 }
87 
88 Status EstimateResourceFromValidationResult(const string& path,
89  ResourceAllocation* estimate) {
90  return EstimateMainRamBytesFromValidationResult(path, estimate);
91 }
92 
93 Status EstimateResourceFromPath(const string& path, bool use_validation_result,
94  ResourceAllocation* estimate) {
95  TensorflowFileProbingEnv env(Env::Default());
96  return EstimateMainRamBytesFromPath(path, use_validation_result, &env,
97  estimate);
98 }
99 
100 Status WrapSessionForBatching(const BatchingParameters& batching_config,
101  std::shared_ptr<Batcher> batch_scheduler,
102  const std::vector<SignatureDef>& signatures,
103  std::unique_ptr<Session>* session) {
104  LOG(INFO) << "Wrapping session to perform batch processing";
105 
106  if (batch_scheduler == nullptr) {
107  return errors::Internal("batch_scheduler not set");
108  }
109  if (*session == nullptr) {
110  return errors::Internal("session not set");
111  }
112 
113  if (!batching_config.allowed_batch_sizes().empty()) {
114  // Verify that the last allowed batch size matches the max batch size.
115  const int last_allowed_size = batching_config.allowed_batch_sizes(
116  batching_config.allowed_batch_sizes().size() - 1);
117  const int max_size = batching_config.has_max_batch_size()
118  ? batching_config.max_batch_size().value()
119  : Batcher::QueueOptions().input_batch_size_limit;
120  if (last_allowed_size != max_size) {
121  return errors::InvalidArgument(
122  "Last entry in allowed_batch_sizes must match max_batch_size; last "
123  "entry was ",
124  last_allowed_size, "; expected ", max_size);
125  }
126  }
127 
128  auto queue_options = GetQueueOptions<
130  batching_config,
131  [](std::unique_ptr<tensorflow::serving::BatchingSessionTask>* input_task,
132  int open_batch_remaining_slot, int max_batch_size,
133  std::vector<std::unique_ptr<tensorflow::serving::BatchingSessionTask>>*
134  output_tasks) -> tensorflow::Status {
135  return SplitInputTask(input_task, open_batch_remaining_slot,
136  max_batch_size, output_tasks);
137  });
138 
139  BatchingSessionOptions batching_session_options;
140  for (int allowed_batch_size : batching_config.allowed_batch_sizes()) {
141  batching_session_options.allowed_batch_sizes.push_back(allowed_batch_size);
142  }
143 
144  batching_session_options.pad_variable_length_inputs =
145  batching_config.pad_variable_length_inputs();
146 
147  auto create_queue = [batch_scheduler, queue_options](
148  std::function<void(std::unique_ptr<Batch<BatchingSessionTask>>)>
149  process_batch_callback,
150  std::unique_ptr<BatchScheduler<BatchingSessionTask>>* queue) {
151  TF_RETURN_IF_ERROR(batch_scheduler->AddQueue(
152  queue_options, process_batch_callback, queue));
153  return absl::OkStatus();
154  };
155  std::vector<SignatureWithBatchingSessionSchedulerCreator>
156  signatures_with_scheduler_creators;
157  for (const SignatureDef& signature : signatures) {
158  const TensorSignature tensor_signature =
159  TensorSignatureFromSignatureDef(signature);
160  signatures_with_scheduler_creators.push_back(
161  {tensor_signature, create_queue});
162  }
163 
164  return CreateBatchingSession(batching_session_options,
165  signatures_with_scheduler_creators, create_queue,
166  std::move(*session), session);
167 }
168 
169 Status WrapSession(std::unique_ptr<Session>* session) {
170  session->reset(new ServingSessionWrapper(std::move(*session)));
171  return absl::OkStatus();
172 }
173 
174 Status WrapSessionIgnoreThreadPoolOptions(std::unique_ptr<Session>* session) {
175  session->reset(
176  new SessionWrapperIgnoreThreadPoolOptions(std::move(*session)));
177  return absl::OkStatus();
178 }
179 
180 } // namespace serving
181 } // namespace tensorflow