16 #include "tensorflow_serving/batching/batching_session.h"
24 #include <unordered_map>
28 #include "absl/container/fixed_array.h"
29 #include "tensorflow/core/framework/cost_graph.pb.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_util.h"
33 #include "tensorflow/core/kernels/batching_util/input_split_metadata.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/notification.h"
36 #include "tensorflow/core/lib/core/status.h"
37 #include "tensorflow/core/lib/gtl/cleanup.h"
38 #include "tensorflow/core/lib/monitoring/counter.h"
39 #include "tensorflow/core/lib/monitoring/percentile_sampler.h"
40 #include "tensorflow/core/lib/monitoring/sampler.h"
41 #include "tensorflow/core/lib/strings/str_util.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/platform/threadpool_options.h"
44 #include "tensorflow/core/platform/types.h"
45 #include "tensorflow/core/profiler/lib/traceme.h"
46 #include "tensorflow/core/profiler/lib/traceme_encode.h"
47 #include "tensorflow/core/public/session.h"
48 #include "tensorflow_serving/batching/batching_util.h"
49 #include "tensorflow_serving/batching/incremental_barrier.h"
50 #include "tensorflow_serving/batching/threadsafe_status.h"
51 #include "tensorflow_serving/servables/tensorflow/serving_session.h"
52 #include "tensorflow_serving/util/hash.h"
54 namespace tensorflow {
59 auto* queuing_latency = monitoring::Sampler<1>::New(
60 {
"/tensorflow/serving/batching_session/queuing_latency",
61 "Distribution of wall time spent (in microseconds) in queuing",
64 monitoring::Buckets::Exponential(100, 1.2, 52));
66 auto* wrapped_run_count = monitoring::Counter<0>::New(
67 "/tensorflow/serving/batching_session/wrapped_run_count",
68 "Total count of run calls on the wrapped session");
70 string TensorSignatureDebugString(
const TensorSignature& signature) {
71 return strings::StrCat(
"{input_tensors: <",
72 absl::StrJoin(signature.input_tensors,
", "),
73 ">, output_tensors: <",
74 absl::StrJoin(signature.output_tensors,
", "),
">}");
77 struct HashTensorSignature {
78 uint64_t operator()(
const TensorSignature& signature)
const {
79 uint64_t hash = 0xDECAFCAFFE ;
80 for (
const string& input_tensor : signature.input_tensors) {
81 hash = HashCombine(hash, std::hash<string>()(input_tensor));
83 for (
const string& output_tensor : signature.output_tensors) {
84 hash = HashCombine(hash, std::hash<string>()(output_tensor));
90 struct EqTensorSignature {
91 bool operator()(
const TensorSignature& lhs,
92 const TensorSignature& rhs)
const {
93 return lhs.input_tensors == rhs.input_tensors &&
94 lhs.output_tensors == rhs.output_tensors;
100 TensorSignature TensorSignatureFromRunArgs(
101 const std::vector<std::pair<string, Tensor>>& inputs,
102 const std::vector<string>& output_tensor_names) {
103 TensorSignature signature;
104 for (
const auto& entry : inputs) {
105 const string& tensor_name = entry.first;
106 signature.input_tensors.insert(tensor_name);
108 for (
const string& output_tensor_name : output_tensor_names) {
109 signature.output_tensors.insert(output_tensor_name);
115 const std::vector<std::pair<string, Tensor>>& GetTaskInput(
116 const BatchingSessionTask& batching_session_task) {
117 if (batching_session_task.is_partial) {
118 return *batching_session_task.owned_split_inputs;
120 return *batching_session_task.inputs;
125 std::vector<std::vector<std::pair<string, Tensor>>> GetTaskInputsVector(
126 const Batch<BatchingSessionTask>& batch) {
127 std::vector<std::vector<std::pair<string, Tensor>>> all_task_inputs;
128 all_task_inputs.reserve(batch.num_tasks());
129 for (
int i = 0; i < batch.num_tasks(); ++i) {
130 all_task_inputs.push_back(GetTaskInput(batch.task(i)));
132 return all_task_inputs;
137 TensorSignature TensorSignatureFromSignatureDef(
138 const SignatureDef& signature_def) {
139 return TensorSignatureFromSignatureDefs({signature_def});
142 TensorSignature TensorSignatureFromSignatureDefs(
143 const std::vector<SignatureDef>& signature_defs) {
144 TensorSignature tensor_signature;
145 for (
const SignatureDef& signature_def : signature_defs) {
146 for (
const auto& entry : signature_def.inputs()) {
147 const TensorInfo& tensor_info = entry.second;
148 tensor_signature.input_tensors.insert(tensor_info.name());
150 for (
const auto& entry : signature_def.outputs()) {
151 const TensorInfo& tensor_info = entry.second;
152 tensor_signature.output_tensors.insert(tensor_info.name());
155 return tensor_signature;
// Factory declaration. Per the out-of-line definition, it constructs a
// BatchingSession around a wrapped session, creates one batch scheduler per
// entry of `signatures_with_scheduler_creators`, and stores the new object in
// `*result`.
// NOTE(review): the listing jumps from 169 to 171; the definition also takes
// `options` and `wrapped` ahead of these parameters — confirm the declaration
// still carries them.
169 static Status Create(
171 const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
172 signatures_with_scheduler_creators,
173 const std::string& thread_pool_name,
174 std::unique_ptr<BatchingSession>* result);
// Factory overload that additionally takes `default_creator`. Per the
// definition, it delegates to the other Create() overload and then stores
// `default_creator` in `default_scheduler_creator_`, which InternalRun() uses
// to build schedulers for signatures that were not declared up front.
// NOTE(review): the listing jumps from 179 to 181; the leading `options` /
// `wrapped` parameters are presumably on the missing line — confirm.
179 static Status Create(
181 const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
182 signatures_with_scheduler_creators,
183 BatchingSessionSchedulerCreator default_creator,
184 const std::string& thread_pool_name,
185 std::unique_ptr<BatchingSession>* result);
// Run() without options or metadata. The definition forwards to the
// RunOptions overload with default-constructed RunOptions and a local
// RunMetadata that is discarded afterwards.
189 Status Run(
const std::vector<std::pair<string, Tensor>>& inputs,
190 const std::vector<string>& output_tensor_names,
191 const std::vector<string>& target_node_names,
192 std::vector<Tensor>* outputs)
override;
// Run() with caller-supplied options and metadata. The definition forwards to
// InternalRun() with no thread pool options (absl::nullopt).
209 Status Run(
const RunOptions& run_options,
210 const std::vector<std::pair<string, Tensor>>& inputs,
211 const std::vector<string>& output_tensor_names,
212 const std::vector<string>& target_node_names,
213 std::vector<Tensor>* outputs, RunMetadata* run_metadata)
override;
// Run() with caller-supplied thread pool options. The definition forwards all
// arguments, including `thread_pool_options`, to InternalRun().
219 Status Run(
const RunOptions& run_options,
220 const std::vector<std::pair<string, Tensor>>& inputs,
221 const std::vector<string>& output_tensor_names,
222 const std::vector<string>& target_node_names,
223 std::vector<Tensor>* outputs, RunMetadata* run_metadata,
224 const thread::ThreadPoolOptions& thread_pool_options)
override;
// Lists devices by delegating to the wrapped session (see the definition).
226 Status ListDevices(std::vector<DeviceAttributes>* response)
override;
// NOTE(review): this region is a run of member-function declarations whose
// opening lines are partly missing from the listing (229, 233, 246, 259, 270
// are skipped). The comments below name each declaration by matching its
// parameters against the out-of-line definitions — confirm against the full
// file.

// Tail of the constructor declaration; the definition takes `options` and
// `thread_pool_name` and stores both.
230 const std::string& thread_pool_name);
// Parameter list that matches InternalRun(): the shared implementation behind
// the Run() overloads, with optional thread pool options.
234 const RunOptions& run_options,
235 const std::vector<std::pair<string, Tensor>>& inputs,
236 const std::vector<string>& output_tensor_names,
237 const std::vector<string>& target_node_names,
238 std::vector<Tensor>* outputs, RunMetadata* run_metadata,
239 absl::optional<thread::ThreadPoolOptions> thread_pool_options);
// Computes the batch size of `inputs`; the definition writes it through a
// `size_t* size` output parameter (not visible on these lines).
245 Status ComputeInputSize(
const std::vector<std::pair<string, Tensor>>& inputs,
// Concatenates the per-task input tensors of `batch` into `merged_inputs`,
// one merged tensor per input name in `signature`.
253 Status MergeInputTensors(
254 const TensorSignature& signature,
const Batch<BatchingSessionTask>& batch,
255 std::vector<std::pair<string, Tensor>>* merged_inputs);
// Parameters that match SplitOutputTensors(): splits `combined_outputs` back
// into per-task outputs for `batch`.
260 const std::vector<Tensor>& combined_outputs,
261 Batch<BatchingSessionTask>* batch);
// Distributes the batch-level run metadata to the tasks in `batch`.
265 Status SplitRunMetadata(RunMetadata* batch_metadata,
266 Batch<BatchingSessionTask>* batch);
// Parameter that matches ProcessBatch(): takes ownership of the batch to run.
271 std::unique_ptr<Batch<BatchingSessionTask>> batch);
// Name passed at construction; used as the label for the queuing latency
// metric and in the Run() trace annotation.
276 const std::string thread_pool_name_;
// The session that actually executes (merged) Run() calls.
278 std::unique_ptr<Session> wrapped_;
// NOTE(review): the listing skips lines 279 and 282 here. These two lines are
// the trailing template arguments of a map type keyed by TensorSignature
// (hashed/compared by the functors above) — presumably the declaration of
// `batch_schedulers_`, which Create() fills and InternalRun() searches.
280 std::unique_ptr<BatchScheduler<BatchingSessionTask>>,
281 HashTensorSignature, EqTensorSignature>
// When set, InternalRun() uses this creator to build a scheduler on demand
// for a request whose signature is not in `batch_schedulers_`.
288 absl::optional<BatchingSessionSchedulerCreator> default_scheduler_creator_;
// Schedulers created on demand through `default_scheduler_creator_`, keyed by
// signature. Annotated as guarded by `mu_`; InternalRun() takes `mu_` before
// touching it.
// NOTE(review): the opening line of this map type (290) is missing from the
// listing.
291 std::unique_ptr<BatchScheduler<BatchingSessionTask>>,
292 HashTensorSignature, EqTensorSignature>
293 custom_signature_batch_schedulers_ ABSL_GUARDED_BY(mu_);
298 Status BatchingSession::Create(
300 const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
301 signatures_with_scheduler_creators,
302 BatchingSessionSchedulerCreator default_creator,
303 const std::string& thread_pool_name,
304 std::unique_ptr<BatchingSession>* result) {
305 auto status = BatchingSession::Create(options, std::move(wrapped),
306 signatures_with_scheduler_creators,
307 thread_pool_name, result);
308 result->get()->default_scheduler_creator_ = default_creator;
// Creates a BatchingSession that wraps `wrapped`.
//
// For every entry of `signatures_with_scheduler_creators`, the entry's
// scheduler creator is invoked with a callback that routes batches to
// ProcessBatch() for that entry's signature, and the resulting scheduler is
// stored in `batch_schedulers_` under that signature. On success the new
// session is moved into `*result`; if a scheduler creator fails, its error is
// returned and `*result` is not assigned.
312 Status BatchingSession::Create(
313 const BatchingSessionOptions& options, std::unique_ptr<Session> wrapped,
314 const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
315 signatures_with_scheduler_creators,
316 const std::string& thread_pool_name,
317 std::unique_ptr<BatchingSession>* result) {
318 auto batching_session = std::unique_ptr<BatchingSession>(
319 new BatchingSession(options, thread_pool_name));
// The callbacks below capture a raw pointer to the session, because the
// owning unique_ptr is moved into `*result` at the end of this function.
320 BatchingSession* raw_batching_session = batching_session.get();
321 batching_session->wrapped_ = std::move(wrapped);
323 for (
const auto& entry : signatures_with_scheduler_creators) {
324 const TensorSignature& signature = entry.signature;
325 const BatchingSessionSchedulerCreator& scheduler_creator =
326 entry.scheduler_creator;
328 std::unique_ptr<BatchScheduler<BatchingSessionTask>> batch_scheduler;
// `signature` is captured by value, so the callback does not depend on the
// lifetime of `signatures_with_scheduler_creators`.
// NOTE(review): the listing skips lines 333-334 (end of the lambda and,
// presumably, the `&batch_scheduler` output argument) — confirm.
329 TF_RETURN_IF_ERROR(scheduler_creator(
330 [signature, raw_batching_session](
331 std::unique_ptr<Batch<BatchingSessionTask>> batch) {
332 raw_batching_session->ProcessBatch(signature, std::move(batch));
// operator[] means a later entry with an equal signature replaces the
// scheduler stored by an earlier one.
335 batching_session->batch_schedulers_[signature] = std::move(batch_scheduler);
338 *result = std::move(batching_session);
339 return absl::OkStatus();
342 Status BatchingSession::Run(
343 const std::vector<std::pair<string, Tensor>>& inputs,
344 const std::vector<string>& output_tensor_names,
345 const std::vector<string>& target_node_names,
346 std::vector<Tensor>* outputs) {
347 RunMetadata run_metadata;
348 return Run(RunOptions(), inputs, output_tensor_names, target_node_names,
349 outputs, &run_metadata);
352 Status BatchingSession::Run(
353 const RunOptions& run_options,
354 const std::vector<std::pair<string, Tensor>>& inputs,
355 const std::vector<string>& output_tensor_names,
356 const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
357 RunMetadata* run_metadata) {
358 return InternalRun(run_options, inputs, output_tensor_names,
359 target_node_names, outputs, run_metadata, absl::nullopt);
362 Status BatchingSession::Run(
363 const RunOptions& run_options,
364 const std::vector<std::pair<string, Tensor>>& inputs,
365 const std::vector<string>& output_tensor_names,
366 const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
367 RunMetadata* run_metadata,
368 const thread::ThreadPoolOptions& thread_pool_options) {
369 return InternalRun(run_options, inputs, output_tensor_names,
370 target_node_names, outputs, run_metadata,
371 thread_pool_options);
// Shared implementation behind every Run() overload.
//
// Rejects requests that name target nodes. Otherwise it derives the request's
// TensorSignature and looks for a matching scheduler, first in
// `batch_schedulers_` and then (if a default scheduler creator is set) in
// `custom_signature_batch_schedulers_`, creating one on demand. When no
// scheduler applies, the request bypasses batching and is run directly on the
// wrapped session. When one does, the request is packaged as a
// BatchingSessionTask, handed to the scheduler, and this call blocks until the
// task is signalled as done.
374 Status BatchingSession::InternalRun(
375 const RunOptions& run_options,
376 const std::vector<std::pair<string, Tensor>>& inputs,
377 const std::vector<string>& output_tensor_names,
378 const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
379 RunMetadata* run_metadata,
380 absl::optional<thread::ThreadPoolOptions> thread_pool_options) {
381 if (!target_node_names.empty()) {
382 return errors::PermissionDenied(
383 "BatchingSession does not support target nodes");
// Trace annotation for this call, tagged with the thread pool name.
386 tsl::profiler::TraceMe trace_me([
this] {
387 return tsl::profiler::TraceMeEncode(
388 "BatchingSessionRun",
389 {{
"thread_pool_name", thread_pool_name_}, {
"_r", 1} });
391 const TensorSignature signature =
392 TensorSignatureFromRunArgs(inputs, output_tensor_names);
393 auto batch_scheduler_it = batch_schedulers_.find(signature);
394 if (batch_scheduler_it == batch_schedulers_.end()) {
395 if (default_scheduler_creator_.has_value()) {
// The on-demand scheduler map is only read and written while holding `mu_`.
396 absl::MutexLock l(&mu_);
397 batch_scheduler_it = custom_signature_batch_schedulers_.find(signature);
398 if (batch_scheduler_it == custom_signature_batch_schedulers_.end()) {
399 std::unique_ptr<BatchScheduler<BatchingSessionTask>> batch_scheduler;
// The callback captures `signature` by value and everything else (including
// `this`) by reference.
// NOTE(review): lines 403-404 are missing from the listing (end of the lambda
// and, presumably, the `&batch_scheduler` output argument) — confirm.
400 TF_RETURN_IF_ERROR(default_scheduler_creator_.value()(
401 [&, signature](std::unique_ptr<Batch<BatchingSessionTask>> batch) {
402 ProcessBatch(signature, std::move(batch));
405 custom_signature_batch_schedulers_[signature] =
406 std::move(batch_scheduler);
407 batch_scheduler_it = custom_signature_batch_schedulers_.find(signature);
// NOTE(review): the branch structure between here and line 431 is partly lost
// (lines 408-411, 417-420, 425, 428-430 are skipped). From the log text, the
// code below is the "no declared signature and no default creator" path that
// warns (rate-limited to once per 120 seconds) and runs the request directly
// on the wrapped session.
412 LOG_EVERY_N_SEC(WARNING, 120)
413 <<
"Request doesn't match any declared signature and no default "
414 "scheduler creator specified. Bypassing "
415 "batcher. Request signature is: "
416 << TensorSignatureDebugString(signature);
421 if (thread_pool_options) {
422 return wrapped_->Run(run_options, inputs, output_tensor_names,
423 target_node_names, outputs, run_metadata,
424 thread_pool_options.value());
426 return wrapped_->Run(run_options, inputs, output_tensor_names,
427 target_node_names, outputs, run_metadata);
431 BatchScheduler<BatchingSessionTask>* batch_scheduler =
432 batch_scheduler_it->second.get();
// Package the request as a task. The task stores pointers to the caller's
// `inputs`, `output_tensor_names`, `outputs` and `run_metadata`, so this
// function must not return before the task has completed.
// NOTE(review): `done` and `status` are used below but their declarations
// (lines 433-437) and the assignment of the task's `done` pointer (line 444)
// are missing from the listing — confirm.
438 auto task = std::unique_ptr<BatchingSessionTask>(
new BatchingSessionTask);
439 task->enqueue_time_micros = EnvTime::NowMicros();
440 task->run_options = run_options;
441 TF_RETURN_IF_ERROR(ComputeInputSize(inputs, &task->zeroth_dim_size));
442 task->inputs = &inputs;
443 task->output_tensor_names = &output_tensor_names;
445 task->status = &status;
446 task->outputs = outputs;
447 task->run_metadata = run_metadata;
448 task->thread_pool_options = thread_pool_options;
// State shared between the pieces of this task if it gets split by
// SplitInputTask().
449 task->thread_safe_status = std::make_shared<ThreadSafeStatus>();
450 task->shared_outputs = std::make_shared<std::vector<std::vector<Tensor>>>();
451 task->split_run_metadatas = absl::make_unique<std::vector<RunMetadata>>();
// Block until the batch containing this task has been processed.
// NOTE(review): the final return (lines 455-456, presumably `return status;`)
// is missing from the listing.
453 TF_RETURN_IF_ERROR(batch_scheduler->Schedule(&task));
454 done.WaitForNotification();
458 Status BatchingSession::ListDevices(std::vector<DeviceAttributes>* response) {
459 return wrapped_->ListDevices(response);
462 BatchingSession::BatchingSession(
const BatchingSessionOptions& options,
463 const std::string& thread_pool_name)
464 : options_(options), thread_pool_name_(thread_pool_name) {}
// Determines the batch size of `inputs` and reports it through `size`.
//
// The work is delegated to ComputeTensorBatchSize(), which is given two
// accessors over a (name, tensor) pair: one returning the tensor's number of
// dimensions and one returning the size of a given dimension. Any error from
// the helper is returned as-is. On success, each input tensor's 0th-dimension
// size is recorded with RecordInputBatchSize() and OK is returned.
// NOTE(review): line 469 (presumably the leading `inputs, size` arguments of
// the helper call) and the closing lines of both lambdas are missing from the
// listing — confirm against the full file.
466 Status BatchingSession::ComputeInputSize(
467 const std::vector<std::pair<string, Tensor>>& inputs,
size_t* size)
const {
468 TF_RETURN_IF_ERROR(::tensorflow::serving::ComputeTensorBatchSize(
// Accessor: number of dimensions of the tensor.
470 [](
const std::pair<std::string, Tensor>& tensor) {
471 return tensor.second.shape().dims();
// Accessor: size of dimension `dim` of the tensor.
473 [](
const std::pair<std::string, Tensor>& tensor,
size_t dim) {
474 return tensor.second.shape().dim_size(dim);
// Metrics only: one sample per input tensor.
476 for (
const auto& entry : inputs) {
477 const Tensor& tensor = entry.second;
478 RecordInputBatchSize<BatchingSessionTask>(tensor.shape().dim_size(0));
480 return absl::OkStatus();
// Merges the input tensors of all tasks in `batch` into one tensor per input
// name and appends the (name, merged tensor) pairs to `merged_inputs`.
//
// Steps:
//   1. Compute the padded batch size with RoundToLowestAllowedBatchSize() and
//      the number of padding rows needed to reach it; record padding and
//      batch-size metrics.
//   2. For every task and every input tensor, optionally pad the tensor
//      (when `pad_variable_length_inputs` is set), otherwise verify that its
//      shape matches the first tensor collected under the same name in every
//      dimension except the 0th.
//   3. After the last task, append `padding_size` copies of a one-row slice
//      of the most recently added tensor.
//   4. Check that the collected names match `signature.input_tensors`, then
//      concatenate each name's tensors.
//
// Returns Internal for an empty batch, a signature mismatch or a failed
// concat, and FailedPrecondition for mismatching shapes when padding is off.
483 Status BatchingSession::MergeInputTensors(
484 const TensorSignature& signature,
const Batch<BatchingSessionTask>& batch,
485 std::vector<std::pair<string, Tensor>>* merged_inputs) {
// The DCHECK is backed by a runtime check, since DCHECKs may be compiled out.
486 DCHECK_GE(batch.num_tasks(), 1);
487 if (batch.num_tasks() < 1) {
488 return errors::Internal(
"Batch size expected to be positive; was ",
492 const int lowest_allowed_batch_size =
493 RoundToLowestAllowedBatchSize(options_.allowed_batch_sizes, batch.size());
494 const int padding_size = lowest_allowed_batch_size - batch.size();
495 tsl::profiler::TraceMe trace_me([lowest_allowed_batch_size, padding_size]() {
496 return tsl::profiler::TraceMeEncode(
498 {{
"batch_size_after_padding", lowest_allowed_batch_size},
499 {
"padding_amount", padding_size}});
501 RecordPaddingSize<BatchingSessionTask>(padding_size,
502 lowest_allowed_batch_size);
503 RecordProcessedBatchSize<BatchingSessionTask>(lowest_allowed_batch_size);
// Per input name, the tensors (in task order) that will be concatenated.
506 std::map<string, std::vector<Tensor>> tensors_to_merge;
// Only populated when variable-length padding is enabled; the values are
// passed to AddPadding() below.
509 absl::optional<std::map<string, std::vector<int>>> max_dim_sizes;
510 if (options_.pad_variable_length_inputs) {
511 std::vector<std::vector<std::pair<string, Tensor>>> all_task_inputs =
512 GetTaskInputsVector(batch);
513 max_dim_sizes = CalculateMaxDimSizes(all_task_inputs);
516 for (
int i = 0; i < batch.num_tasks(); ++i) {
517 const std::vector<std::pair<string, Tensor>>& task_inputs =
518 GetTaskInput(batch.task(i));
519 for (
const auto& entry : task_inputs) {
520 const string& tensor_name = entry.first;
521 const Tensor& tensor = entry.second;
523 std::vector<Tensor>& tensor_vec = tensors_to_merge[tensor_name];
524 Tensor optionally_padded_tensor;
525 if (options_.pad_variable_length_inputs) {
526 TF_RETURN_IF_ERROR(AddPadding(tensor, (*max_dim_sizes)[tensor_name],
527 &optionally_padded_tensor));
// NOTE(review): the `else` line (528) and lines 530-532 are missing from the
// listing. The shape comparison below indexes element [0] of the vector for
// this name, so it presumably runs only once a first tensor has been stored
// (e.g. guarded by `i > 0`) — confirm.
529 optionally_padded_tensor = tensor;
533 TensorShape reference_shape =
534 tensors_to_merge[tensor_name][0].shape();
535 if (!AreShapesEqualExceptZeroDim(tensor.shape(), reference_shape)) {
536 return errors::FailedPrecondition(
537 "Tensors with name '" + tensor_name +
538 "' from different tasks have different shapes and padding is "
539 "turned off. Set pad_variable_length_inputs to true, or ensure "
540 "that all tensors with the same name have equal dimensions "
541 "starting with the first dim.");
545 tensor_vec.push_back(std::move(optionally_padded_tensor));
// On the last task, top the vector up to the allowed batch size by repeating
// a one-row slice of the tensor that was just added. Note that the inner loop
// variable shadows the outer task index `i`.
546 if (i == batch.num_tasks() - 1 && padding_size > 0) {
555 const Tensor padding_tensor = tensor_vec.back().Slice(0, 1);
556 for (
int i = 0; i < padding_size; ++i) {
557 tensor_vec.push_back(padding_tensor);
// Every input name of the signature must have been seen, and nothing else.
564 DCHECK_EQ(signature.input_tensors.size(), tensors_to_merge.size());
565 if (tensors_to_merge.size() != signature.input_tensors.size()) {
566 return errors::Internal(
567 "One or more tasks does not conform to batch signature");
569 for (
const string& tensor_name : signature.input_tensors) {
570 auto tensors = tensors_to_merge.find(tensor_name);
571 DCHECK(tensors != tensors_to_merge.end());
572 if (tensors == tensors_to_merge.end()) {
573 return errors::Internal(
574 "One or more tasks does not conform to batch signature");
// NOTE(review): the declaration of `concated` (line 576) is missing from the
// listing.
577 const Status concat_status = tensor::Concat(tensors->second, &concated);
578 DCHECK(concat_status.ok()) << concat_status.ToString();
579 if (!concat_status.ok()) {
580 return errors::Internal(
"Tensor concat operation failed: ",
581 concat_status.ToString());
583 merged_inputs->push_back({tensor_name, std::move(concated)});
586 return absl::OkStatus();
// Splits the batched output tensors in `combined_outputs` along the 0th
// dimension and hands each task in `batch` its share.
//
// The split sizes are the tasks' `zeroth_dim_size` values, followed by one
// extra chunk for padding rows when the batch was padded. Each combined
// output must have at least one dimension, and its 0th dimension must equal
// the batch size plus the padding. The i-th combined output is matched to the
// i-th name of `signature.output_tensors` in iteration order. For partial
// (split) tasks the piece is appended to `shared_outputs[split_index]`;
// otherwise it is appended to the task's `outputs`.
//
// Returns Internal for an empty batch, a wrong number of outputs, a failed or
// unexpected split, or a task asking for an unknown output name, and
// FailedPrecondition for outputs with unusable shapes.
589 Status BatchingSession::SplitOutputTensors(
590 const TensorSignature& signature,
591 const std::vector<Tensor>& combined_outputs,
592 Batch<BatchingSessionTask>* batch) {
593 DCHECK_GE(batch->num_tasks(), 1);
594 if (batch->num_tasks() < 1) {
595 return errors::Internal(
"Batch size expected to be positive; was ",
599 std::vector<int64_t> task_sizes_plus_optional_padding;
600 task_sizes_plus_optional_padding.reserve(batch->num_tasks());
601 for (
int i = 0; i < batch->num_tasks(); ++i) {
602 task_sizes_plus_optional_padding.push_back(batch->task(i).zeroth_dim_size);
// NOTE(review): line 606 (the subtrahend, presumably `batch->size()`) is
// missing from the listing; MergeInputTensors() computes padding as the
// rounded size minus `batch.size()`.
604 const int padding_size = RoundToLowestAllowedBatchSize(
605 options_.allowed_batch_sizes, batch->size()) -
// The padding rows form one trailing chunk; no task reads it back.
607 if (padding_size > 0) {
608 task_sizes_plus_optional_padding.push_back(padding_size);
// Output name -> per-chunk tensors, indexed like
// `task_sizes_plus_optional_padding`.
612 std::map<string, std::vector<Tensor>> split_tensors;
615 DCHECK_EQ(signature.output_tensors.size(), combined_outputs.size());
616 if (combined_outputs.size() != signature.output_tensors.size()) {
617 return errors::Internal(
"Wrong number of batched output tensors");
// Copy the names into a vector so they can be indexed in parallel with
// `combined_outputs`.
619 const std::vector<string> output_tensors(signature.output_tensors.begin(),
620 signature.output_tensors.end());
621 for (
int i = 0; i < output_tensors.size(); ++i) {
622 const string& tensor_name = output_tensors[i];
623 const Tensor& tensor = combined_outputs[i];
625 if (tensor.shape().dims() == 0) {
626 return errors::FailedPrecondition(
627 "Batched output tensor has 0 dimensions");
629 if (tensor.shape().dim_size(0) != batch->size() + padding_size) {
630 return errors::FailedPrecondition(
631 "Batched output tensor's 0th dimension does not equal the sum of the "
632 "0th dimension sizes of the input tensors");
635 std::vector<Tensor> split_tensor;
636 const Status split_status =
637 tensor::Split(tensor, task_sizes_plus_optional_padding, &split_tensor);
638 DCHECK(split_status.ok()) << split_status.ToString();
639 if (!split_status.ok()) {
640 return errors::Internal(
"Tensor split operation failed: ",
641 split_status.ToString());
643 DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
644 if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
645 return errors::Internal(
646 "Tensor split operation did not work as expected; got ",
647 split_tensor.size(),
" splits; expected ",
648 task_sizes_plus_optional_padding.size());
650 split_tensors[tensor_name] = std::move(split_tensor);
// Distribute chunk i of every requested output to task i. The chunk is moved
// out of `split_tensors`.
653 for (
int i = 0; i < batch->num_tasks(); ++i) {
654 BatchingSessionTask* task = batch->mutable_task(i);
655 for (
const string& tensor_name : *task->output_tensor_names) {
656 auto split_tensor = split_tensors.find(tensor_name);
657 DCHECK(split_tensor != split_tensors.end());
658 if (split_tensor == split_tensors.end()) {
659 return errors::Internal(
"Task does not conform to batch signature");
// NOTE(review): the `else` line (666) between the two branches is missing
// from the listing.
662 if (task->is_partial) {
663 std::vector<Tensor>& tensor_vector =
664 (*task->shared_outputs)[task->split_index];
665 tensor_vector.push_back(std::move(split_tensor->second[i]));
667 task->outputs->push_back(std::move(split_tensor->second[i]));
673 return absl::OkStatus();
// Distributes the run metadata of a batched Run() call to the tasks of
// `batch`.
//
// If `batch_metadata` carries a cost graph, every aggregated cost in it is
// first divided (in place) by the number of tasks. Each task then receives
// the metadata: partial (split) tasks store it in their `split_run_metadatas`
// slot at `split_index`, other tasks get a copy in their `run_metadata` when
// that pointer is non-null. Nothing happens for an empty batch. Always
// returns OK.
676 Status BatchingSession::SplitRunMetadata(RunMetadata* batch_metadata,
677 Batch<BatchingSessionTask>* batch) {
678 if (batch->num_tasks() > 0) {
679 if (batch_metadata->has_cost_graph()) {
// Evenly attribute each aggregated cost to the tasks in the batch.
682 for (
size_t i = 0; i < batch_metadata->cost_graph().cost_size(); ++i) {
683 CostGraphDef_AggregatedCost* cost =
684 batch_metadata->mutable_cost_graph()->mutable_cost(i);
685 const float agg_cost = cost->cost();
686 cost->set_cost(agg_cost /
static_cast<float>(batch->num_tasks()));
690 for (
size_t i = 0; i < batch->num_tasks(); ++i) {
691 BatchingSessionTask* batching_session_task = batch->mutable_task(i);
692 if (batching_session_task->is_partial) {
// NOTE(review): the right-hand side of this assignment (line 698, presumably
// `*batch_metadata;`) and the `else` line (699) are missing from the listing.
696 (*batching_session_task
697 ->split_run_metadatas)[batching_session_task->split_index] =
700 RunMetadata* run_metadata = batching_session_task->run_metadata;
701 if (run_metadata !=
nullptr) {
702 *run_metadata = *batch_metadata;
708 return absl::OkStatus();
// Scheduler callback that executes one batch.
//
// Waits for the batch to close, then:
//   1. Installs a cleanup that, on every exit path, propagates the final
//      `status` to each task and signals its completion.
//   2. Computes per-task deadlines from enqueue time and timeout, records the
//      queuing latency, and fails the whole batch with ResourceExhausted if
//      every task's deadline had already passed at dequeue time.
//   3. Merges the task inputs, runs the wrapped session once, counts the
//      call, and splits run metadata and outputs back to the tasks.
// NOTE(review): several lines of this function are missing from the listing
// (e.g. the declaration of `status`, early `return`s after failures, the end
// of the second wrapped Run() call); the inline notes mark the visible gaps.
711 void BatchingSession::ProcessBatch(
712 const TensorSignature& signature,
713 std::unique_ptr<Batch<BatchingSessionTask>> batch) {
717 batch->WaitUntilClosed();
// NOTE(review): the body of this `if` (presumably an early return for empty
// batches) is missing from the listing.
719 if (batch->empty()) {
723 const uint64_t dequeue_time_micros = EnvTime::NowMicros();
// Runs when this function exits. Partial (split) tasks report through their
// shared thread-safe status and `done_callback`; other tasks get `status`
// written to their status pointer and their `done` notification fired, which
// is what unblocks InternalRun().
729 auto finally = gtl::MakeCleanup([&status, &batch] {
730 for (
int i = 0; i < batch->num_tasks(); ++i) {
731 BatchingSessionTask* task = batch->mutable_task(i);
732 if (task->is_partial) {
733 task->thread_safe_status->Update(status);
734 task->done_callback();
736 *batch->mutable_task(i)->status = status;
737 batch->mutable_task(i)->done->Notify();
// Deadline bookkeeping: the batch deadline is the latest task deadline among
// the tasks that have not timed out yet.
745 bool all_tasks_timeout_exceeded =
true;
746 uint64_t batch_deadline_micros = 0;
747 for (
int i = 0; i < batch->num_tasks(); ++i) {
748 const BatchingSessionTask& task = batch->task(i);
// NOTE(review): line 753 (the value used when the task has no positive
// timeout) is missing from the listing; the INT_MAX comparison further down
// suggests it is INT_MAX — confirm.
751 const int64_t task_timeout_micros =
752 task.run_options.timeout_in_ms() <= 0
754 : task.run_options.timeout_in_ms() * 1000;
755 const uint64_t task_deadline_micros =
756 task.enqueue_time_micros + task_timeout_micros;
757 if (task_deadline_micros > dequeue_time_micros) {
758 all_tasks_timeout_exceeded =
false;
759 if (task_deadline_micros > batch_deadline_micros) {
760 batch_deadline_micros = task_deadline_micros;
// One latency sample per task, labelled with the thread pool name.
763 queuing_latency->GetCell(thread_pool_name_)
764 ->Add(dequeue_time_micros - task.enqueue_time_micros);
766 if (all_tasks_timeout_exceeded) {
767 status = Status(
static_cast<tensorflow::errors::Code
>(
768 absl::StatusCode::kResourceExhausted),
769 "Run() timeout exceeded while waiting in batching queue");
// The merged Run() starts from the first task's options; only the timeout is
// overridden with the time remaining until the batch deadline.
// NOTE(review): `batch_deadline_micros` is an enqueue time plus a timeout, so
// it looks unlikely to ever equal INT_MAX exactly; if this is meant to detect
// "no timeout requested", the check may never fire — confirm intent.
773 RunOptions run_options = batch->task(0).run_options;
774 if (batch_deadline_micros == INT_MAX) {
775 run_options.set_timeout_in_ms(0);
777 run_options.set_timeout_in_ms(
778 (batch_deadline_micros - dequeue_time_micros) / 1000);
781 std::vector<std::pair<string, Tensor>> merged_inputs;
782 status = MergeInputTensors(signature, *batch, &merged_inputs);
// Thread pool options are taken from the first task only.
787 absl::optional<thread::ThreadPoolOptions> thread_pool_options =
788 batch->task(0).thread_pool_options;
790 const std::vector<string> output_tensor_names(
791 signature.output_tensors.begin(), signature.output_tensors.end());
792 std::vector<Tensor> combined_outputs;
793 RunMetadata run_metadata;
// The empty braces are the (unsupported) target node names.
797 if (thread_pool_options) {
798 status = wrapped_->Run(run_options, merged_inputs, output_tensor_names,
799 {} , &combined_outputs,
800 &run_metadata, thread_pool_options.value());
802 status = wrapped_->Run(run_options, merged_inputs, output_tensor_names,
803 {} , &combined_outputs,
806 wrapped_run_count->GetCell()->IncrementBy(1);
807 status.Update(SplitRunMetadata(&run_metadata, batch.get()));
813 status = SplitOutputTensors(signature, combined_outputs, batch.get());
// Splits one large task into several partial tasks and appends them to
// `output_tasks`.
//
// The split sizes come from internal::InputSplitMetadata, built from the
// task's size, `open_batch_remaining_slot` and `max_batch_size`. Every
// partial task shares the original task's output buffer (`shared_outputs`),
// status (`thread_safe_status`) and run-metadata vector, and owns its slice
// of each input tensor in `owned_split_inputs`. An IncrementalBarrier hands
// each partial task a `done_callback`; `split_task_done_callback` is the
// closure given to the barrier and finishes the original task.
//
// Returns Internal if an input tensor cannot be split as expected, else OK.
820 Status SplitInputTask(
821 std::unique_ptr<BatchingSessionTask>* input_task_ptr,
822 int open_batch_remaining_slot,
int max_batch_size,
823 std::vector<std::unique_ptr<BatchingSessionTask>>* output_tasks) {
824 BatchingSessionTask& input_task = *(*input_task_ptr);
825 const int64_t input_task_size = input_task.size();
827 DCHECK_GT(input_task_size, 0);
// Completion closure for the original task. It copies the pointers and shared
// handles it needs out of `input_task`, so it does not refer to `input_task`
// itself when it runs. Presumably invoked by the barrier once every partial
// task has called its `done_callback` — confirm against IncrementalBarrier.
830 std::function<void()> split_task_done_callback =
831 [done_notification = input_task.done,
832 shared_outputs = input_task.shared_outputs,
833 shared_status = input_task.thread_safe_status,
834 num_output = input_task.output_tensor_names->size(),
835 outputs = input_task.outputs, status = input_task.status,
836 run_metadata = input_task.run_metadata,
837 split_run_metadatas = input_task.split_run_metadatas]() {
// Whatever happens below, publish the shared status to the caller and wake
// it up.
838 auto finally = gtl::MakeCleanup([&] {
839 *status = shared_status->status();
840 done_notification->Notify();
// NOTE(review): the body of this `if` (presumably an early return when any
// partial task failed) is missing from the listing.
845 if (!shared_status->status().ok()) {
// For output index i, stitch together piece i of every split, in split order.
849 for (
int i = 0; i < num_output; ++i) {
850 Tensor output_tensor;
853 std::vector<Tensor> to_concatenate;
854 to_concatenate.reserve(shared_outputs->size());
855 for (
int j = 0; j < shared_outputs->size(); ++j) {
856 to_concatenate.push_back(std::move((*shared_outputs)[j][i]));
858 const auto concat_status =
859 tensor::Concat(to_concatenate, &output_tensor);
860 if (!concat_status.ok()) {
861 shared_status->Update(concat_status);
865 outputs->push_back(std::move(output_tensor));
// Sum the per-dimension costs reported by all splits.
870 absl::flat_hash_map<string, float> cost_dimension_map;
871 for (
const auto& split : *split_run_metadatas) {
872 if (split.has_cost_graph()) {
873 for (
const auto& cost : split.cost_graph().cost()) {
874 cost_dimension_map[cost.dimension()] += cost.cost();
// Start from the first split's metadata, then replace its cost entries with
// the summed values, keeping the first split's dimension order.
// NOTE(review): `run_metadata` is dereferenced without a null check here,
// whereas SplitRunMetadata() checks it for non-partial tasks — confirm that
// callers always supply it for split tasks.
879 *run_metadata = (*split_run_metadatas)[0];
880 std::vector<string> cost_dimensions;
881 for (
const auto& cost_and_dimension :
882 run_metadata->cost_graph().cost()) {
883 cost_dimensions.push_back(cost_and_dimension.dimension());
885 run_metadata->mutable_cost_graph()->clear_cost();
886 for (
const auto& dimension : cost_dimensions) {
887 const auto iter = cost_dimension_map.find(dimension);
888 if (iter != cost_dimension_map.end()) {
889 auto graph_cost = run_metadata->mutable_cost_graph()->add_cost();
890 graph_cost->set_dimension(iter->first);
891 graph_cost->set_cost(iter->second);
895 IncrementalBarrier barrier(split_task_done_callback);
897 const internal::InputSplitMetadata input_split_metadata(
898 input_task_size, open_batch_remaining_slot, max_batch_size);
// Sizes of the partial tasks, copied into a fixed-size array.
902 const absl::FixedArray<int64_t> output_task_sizes(
903 input_split_metadata.task_sizes().begin(),
904 input_split_metadata.task_sizes().end());
905 const int num_batches = output_task_sizes.size();
// One output slot and one run-metadata slot per partial task; partial tasks
// address their slot by `split_index`.
907 input_task.shared_outputs->resize(num_batches);
909 for (
int i = 0; i < num_batches; ++i) {
910 (*input_task.shared_outputs)[i].reserve(
911 input_task.output_tensor_names->size());
914 input_task.split_run_metadatas->resize(num_batches);
916 output_tasks->reserve(num_batches);
917 for (
int i = 0; i < num_batches; i++) {
918 auto task = absl::make_unique<BatchingSessionTask>();
919 task->enqueue_time_micros = input_task.enqueue_time_micros;
920 task->run_options = input_task.run_options;
921 task->zeroth_dim_size = output_task_sizes[i];
923 task->output_tensor_names = input_task.output_tensor_names;
925 task->owned_split_inputs =
926 absl::make_unique<std::vector<std::pair<string, Tensor>>>();
927 task->split_index = i;
928 task->shared_outputs = input_task.shared_outputs;
929 task->thread_safe_status = input_task.thread_safe_status;
930 task->is_partial =
true;
931 task->done_callback = barrier.Inc();
932 task->thread_pool_options = input_task.thread_pool_options;
934 task->split_run_metadatas = input_task.split_run_metadatas;
936 output_tasks->push_back(std::move(task));
939 const int num_input_tensors = input_task.inputs->size();
// Split every input tensor by the task sizes and give piece j to task j.
943 for (
int i = 0; i < num_input_tensors; ++i) {
944 std::vector<Tensor> split_tensors;
945 const string& tensor_name = (*input_task.inputs)[i].first;
946 const Tensor& input_tensor = (*input_task.inputs)[i].second;
950 const Status split_status =
951 tensor::Split(input_tensor, output_task_sizes, &split_tensors);
952 if (!split_status.ok()) {
953 return errors::Internal(
954 "When splitting input, Tensor split operation failed: ",
955 split_status.ToString());
957 if (split_tensors.size() != output_task_sizes.size()) {
958 return errors::Internal(
959 "When splitting input, tensor split operation did not work as "
961 split_tensors.size(),
" splits; expected ", output_task_sizes.size());
963 for (
int j = 0; j < output_tasks->size(); ++j) {
964 BatchingSessionTask& output_task = *((*output_tasks)[j]);
965 output_task.owned_split_inputs->push_back(
966 std::make_pair(tensor_name, split_tensors[j]));
969 return absl::OkStatus();
972 Status CreateBatchingSession(
973 const BatchingSessionOptions& options,
974 const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
975 signatures_with_scheduler_creators,
976 BatchingSessionSchedulerCreator default_creator,
977 std::unique_ptr<Session> session,
978 std::unique_ptr<Session>* batching_session) {
979 std::unique_ptr<BatchingSession> internal_batching_session;
980 TF_RETURN_IF_ERROR(BatchingSession::Create(
981 options, std::move(session), signatures_with_scheduler_creators,
982 default_creator,
"", &internal_batching_session));
983 *batching_session = std::move(internal_batching_session);
984 return absl::OkStatus();
987 Status CreateBatchingSession(
988 const BatchingSessionOptions& options,
989 const std::vector<SignatureWithBatchingSessionSchedulerCreator>&
990 signatures_with_scheduler_creators,
991 std::unique_ptr<Session> session,
992 std::unique_ptr<Session>* batching_session) {
993 std::unique_ptr<BatchingSession> internal_batching_session;
994 TF_RETURN_IF_ERROR(BatchingSession::Create(
995 options, std::move(session), signatures_with_scheduler_creators,
996 "", &internal_batching_session));
997 *batching_session = std::move(internal_batching_session);
998 return absl::OkStatus();
// Creates a batching session for a single `signature`, backed by a
// BasicBatchScheduler configured with `schedule_options`.
//
// When `allowed_batch_sizes` is non-empty it is validated first:
//   * with large-batch splitting enabled, the entries are checked for strictly
//     increasing order, the last entry must not exceed `max_batch_size`, and
//     it must equal `max_execution_batch_size`;
//   * otherwise the last entry must equal `max_batch_size`.
// Validation failures return InvalidArgument. On success the new session is
// stored in `*batching_session`.
1001 Status CreateBasicBatchingSession(
1002 const BasicBatchScheduler<BatchingSessionTask>::Options& schedule_options,
1003 const BatchingSessionOptions& batching_session_options,
1004 const TensorSignature& signature, std::unique_ptr<Session> session,
1005 std::unique_ptr<Session>* batching_session) {
1006 const auto& allowed_batch_sizes =
1007 batching_session_options.allowed_batch_sizes;
1008 if (!allowed_batch_sizes.empty()) {
1009 if (schedule_options.enable_large_batch_splitting) {
1010 const int max_allowed_batch_size = allowed_batch_sizes.back();
// NOTE(review): lines 1017-1019 are missing from the listing, and `last_size`
// is never updated in the visible lines. Without an update such as
// `last_size = size;` the check below would only reject non-positive entries
// after the first — confirm the update exists in the full file.
1011 int32 last_size = 0;
1012 for (
size_t i = 0; i < allowed_batch_sizes.size(); ++i) {
1013 const int32 size = allowed_batch_sizes.at(i);
1014 if (i > 0 && size <= last_size) {
1015 return errors::InvalidArgument(
1016 "allowed_batch_sizes entries must be monotonically increasing");
1020 if (max_allowed_batch_size > schedule_options.max_batch_size) {
1021 return errors::InvalidArgument(
1022 "Last entry in allowed_batch_sizes must be less than or equal to "
1023 "max_batch_size; last "
1025 max_allowed_batch_size,
"; expected ",
1026 schedule_options.max_batch_size);
1028 if (schedule_options.max_execution_batch_size != max_allowed_batch_size) {
1029 return errors::InvalidArgument(
1030 "Last entry in allowed_batch_sizes must be equal to "
1031 "max_execution_batch_size; last "
1033 max_allowed_batch_size,
"; expected ",
1034 schedule_options.max_execution_batch_size);
1036 }
else if (allowed_batch_sizes.back() != schedule_options.max_batch_size) {
1040 return errors::InvalidArgument(
1041 "Last entry in allowed_batch_sizes must match max_batch_size; last "
1043 batching_session_options.allowed_batch_sizes.back(),
"; expected ",
1044 schedule_options.max_batch_size);
// Scheduler creator handed to BatchingSession::Create(): it builds a
// BasicBatchScheduler from `schedule_options` and the batch-processing
// callback, and returns it through the output parameter.
// NOTE(review): lines 1049 and 1053 (the lambda's capture list and the name
// of the output parameter, used below as `batch_scheduler`) are missing from
// the listing.
1048 auto scheduler_creator =
1050 std::function<void(std::unique_ptr<Batch<BatchingSessionTask>>)>
1051 process_batch_callback,
1052 std::unique_ptr<BatchScheduler<BatchingSessionTask>>*
1054 std::unique_ptr<BasicBatchScheduler<BatchingSessionTask>>
1055 basic_batch_scheduler;
1056 TF_RETURN_IF_ERROR(BasicBatchScheduler<BatchingSessionTask>::Create(
1057 schedule_options, process_batch_callback, &basic_batch_scheduler));
1058 *batch_scheduler = std::move(basic_batch_scheduler);
1059 return absl::OkStatus();
// A single (signature, creator) pair; the scheduler's thread pool name is
// passed on as the session's thread pool name.
1062 std::unique_ptr<BatchingSession> internal_batching_session;
1063 TF_RETURN_IF_ERROR(BatchingSession::Create(
1064 batching_session_options, std::move(session),
1065 {{signature, scheduler_creator}}, schedule_options.thread_pool_name,
1066 &internal_batching_session));
1067 *batching_session = std::move(internal_batching_session);
1068 return absl::OkStatus();