19 #ifndef TENSORFLOW_SERVING_BATCHING_BATCHING_SESSION_H_
20 #define TENSORFLOW_SERVING_BATCHING_BATCHING_SESSION_H_
29 #include "absl/types/optional.h"
30 #include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h"
31 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
32 #include "tensorflow/core/platform/threadpool_options.h"
33 #include "tensorflow/core/protobuf/config.pb.h"
34 #include "tensorflow/core/protobuf/meta_graph.pb.h"
35 #include "tensorflow/core/public/session.h"
36 #include "tensorflow_serving/batching/batching_options.h"
37 #include "tensorflow_serving/batching/threadsafe_status.h"
39 namespace tensorflow {
44 struct BatchingSessionTask;
48 using BatchingSessionSchedulerCreator = std::function<Status(
49 std::function<
void(std::unique_ptr<Batch<BatchingSessionTask>>)>,
50 std::unique_ptr<BatchScheduler<BatchingSessionTask>>*)>;
56 std::set<string> input_tensors;
57 std::set<string> output_tensors;
62 const SignatureDef& signature_def);
75 const std::vector<SignatureDef>& signature_defs);
81 BatchingSessionSchedulerCreator scheduler_creator;
129 Status CreateBatchingSession(
131 const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
132 signatures_with_scheduler_creators,
133 std::unique_ptr<Session> session,
134 std::unique_ptr<Session>* batching_session);
138 Status CreateBatchingSession(
140 const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
141 signatures_with_scheduler_creators,
142 BatchingSessionSchedulerCreator default_creator,
143 std::unique_ptr<Session> session,
144 std::unique_ptr<Session>* batching_session);
148 Status CreateBasicBatchingSession(
149 const typename BasicBatchScheduler<BatchingSessionTask>::Options&
153 std::unique_ptr<Session>* batching_session);
159 Status SplitInputTask(
160 std::unique_ptr<BatchingSessionTask>* input_task_ptr,
161 int open_batch_remaining_slot,
int max_batch_size,
162 std::vector<std::unique_ptr<BatchingSessionTask>>* output_tasks);
169 size_t size()
const override {
return zeroth_dim_size; }
// Returns the fixed identifier "batching_session" for this task type.
static std::string Name() {
  return std::string("batching_session");
}
175 uint64_t enqueue_time_micros;
176 RunOptions run_options;
177 size_t zeroth_dim_size;
178 const std::vector<std::pair<string, Tensor>>* inputs;
179 const std::vector<string>* output_tensor_names;
185 std::vector<Tensor>* outputs;
186 RunMetadata* run_metadata;
187 absl::optional<thread::ThreadPoolOptions> thread_pool_options;
192 bool is_partial =
false;
195 std::unique_ptr<std::vector<std::pair<string, Tensor>>> owned_split_inputs;
199 std::function<void()> done_callback;
200 typedef std::vector<std::vector<Tensor>> TensorMatrix;
207 std::shared_ptr<TensorMatrix> shared_outputs;
210 std::shared_ptr<ThreadSafeStatus> thread_safe_status;
212 std::shared_ptr<std::vector<RunMetadata>> split_run_metadatas;