16 #ifndef TENSORFLOW_SERVING_BATCHING_BATCHING_UTIL_H_
17 #define TENSORFLOW_SERVING_BATCHING_BATCHING_UTIL_H_
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/monitoring/sampler.h"
28 namespace tensorflow {
// Declaration only; the definition is not part of this view.
// Takes `batch`, a vector whose elements are vectors of (string, Tensor)
// pairs, and returns a map from string to a vector of ints.
// NOTE(review): judging by the name, the result presumably holds, per tensor
// name, the largest size seen for each dimension across the batch -- confirm
// against the definition.
43 std::map<string, std::vector<int>> CalculateMaxDimSizes(
44 const std::vector<std::vector<std::pair<string, Tensor>>>& batch);
// Declaration only; the definition is not part of this view.
// Receives a read-only `tensor`, a read-only span `max_dim_sizes`, and a
// non-const `padded_tensor` pointer; returns a Status.
// NOTE(review): presumably `padded_tensor` is the output and receives `tensor`
// padded up to the sizes in `max_dim_sizes`; the padding value and the
// validation rules are not visible here -- confirm against the definition.
64 Status AddPadding(
const Tensor& tensor, absl::Span<const int> max_dim_sizes,
65 Tensor* padded_tensor);
// Declaration of a function returning int whose first parameter is a
// read-only span of allowed batch sizes.
// NOTE(review): the parameter list is cut off in this view (the line ends
// with a comma and no closing parenthesis is visible) -- the remaining
// parameter(s) and the rounding rule must be checked against the full header.
70 int RoundToLowestAllowedBatchSize(absl::Span<const int> allowed_batch_sizes,
// Declaration only; the definition is not part of this view.
// Compares two TensorShape values and returns a bool.
// NOTE(review): the name suggests dimension 0 is ignored and all remaining
// dimensions must match -- confirm against the definition.
76 bool AreShapesEqualExceptZeroDim(
const TensorShape& shape1,
77 const TensorShape& shape2);
// Validates `inputs` for batching and checks each tensor's 0th-dimension size
// against `*size`.
//
// Template parameters, as used in the visible code:
//   TensorList  - a range that can be walked with a range-based for loop;
//                 note that `inputs` is taken by value.
//   DimFunc     - callable as dim_func(tensor); its result is compared to 0.
//   DimSizeFunc - callable as dim_size_func(tensor, 0); its result
//                 initializes a size_t.
//
// Every failure path visible here returns errors::InvalidArgument.
//
// NOTE(review): several lines of this definition are not visible in this view
// (the guard in front of the first return, whatever stores into `*size`, the
// success return and the closing braces) -- confirm the details against the
// complete file.
81 template <
typename TensorList,
typename DimFunc,
typename DimSizeFunc>
82 Status ComputeTensorBatchSize(TensorList inputs,
size_t* size, DimFunc dim_func,
83 DimSizeFunc dim_size_func) {
// NOTE(review): the condition guarding this return is not visible; the
// message indicates it covers the case of no input tensors.
85 return errors::InvalidArgument(
86 "Batching Run() must have at least one input tensor");
// Inspect every input tensor in turn.
90 for (
const auto& tensor : inputs) {
// A tensor for which dim_func reports 0 is rejected.
91 if (dim_func(tensor) == 0) {
92 return errors::InvalidArgument(
93 "Batching Run() input tensors must have at least one "
// Size of dimension 0 of the current tensor.
96 const size_t this_size = dim_size_func(tensor, 0);
// A 0th-dimension size that differs from *size is an error.
102 if (this_size != *size) {
103 return errors::InvalidArgument(
104 "Batching Run() input tensors must have equal "
105 "0th-dimension size");
// Records `padding_size` as a sample in a monitoring Sampler, using the cell
// selected by the decimal string form of `execution_batch_size`.
//
// BatchingTask must provide a static Name(); its result becomes part of the
// metric name. Both locals are function-local statics, so they are initialized
// once per template instantiation.
//
// NOTE(review): the absl::StrCat call that builds the metric name is not
// closed in this view, so the trailing part of the name is not visible; the
// closing brace of the function is not visible either -- confirm against the
// complete file.
117 template <
typename BatchingTask>
118 void RecordPaddingSize(int32 padding_size, int32 execution_batch_size) {
119 static const std::string batching_task_name = BatchingTask::Name();
// Sampler<1>: presumably one label, named "execution_batch_size" below --
// confirm against the monitoring::Sampler API.
120 static auto* cell = tensorflow::monitoring::Sampler<1>::New(
121 {absl::StrCat(
"/tensorflow/serving/", batching_task_name,
123 "Tracks the padding size distribution on batches.",
124 "execution_batch_size"},
// NOTE(review): presumably scale 1, growth factor 2, 14 buckets -- confirm
// against monitoring::Buckets::Exponential.
126 monitoring::Buckets::Exponential(1, 2, 14));
// The label value is the batch size rendered as a string; the sample itself
// is the padding size converted to double.
127 cell->GetCell(absl::StrCat(execution_batch_size))
128 ->Add(
static_cast<double>(padding_size));
// Records `batch_size` as a sample in a monitoring Sampler whose metric name
// is "/tensorflow/serving/" + BatchingTask::Name() + "/input_batch_size".
//
// BatchingTask must provide a static Name(). Both locals are function-local
// statics, so they are initialized once per template instantiation.
//
// NOTE(review): the closing brace of this function is not visible in this
// view -- confirm against the complete file.
131 template <
typename BatchingTask>
132 void RecordInputBatchSize(int32 batch_size) {
133 static const std::string batching_task_name = BatchingTask::Name();
// Sampler<0>: no label names are supplied and GetCell() is called without
// arguments below.
134 static auto* cell = tensorflow::monitoring::Sampler<0>::New(
135 {absl::StrCat(
"/tensorflow/serving/", batching_task_name,
136 "/input_batch_size"),
137 "Tracks the batch size distribution on the inputs."},
// NOTE(review): presumably scale 1, growth factor 2, 14 buckets -- confirm
// against monitoring::Buckets::Exponential.
139 monitoring::Buckets::Exponential(1, 2, 14));
// The sample is the batch size converted to double.
140 cell->GetCell()->Add(
static_cast<double>(batch_size));
// Records `batch_size` as a sample in a monitoring Sampler whose metric name
// is "/tensorflow/serving/" + BatchingTask::Name() + "/processed_batch_size".
//
// BatchingTask must provide a static Name(). Both locals are function-local
// statics, so they are initialized once per template instantiation.
//
// NOTE(review): the closing brace of this function is not visible in this
// view -- confirm against the complete file.
143 template <
typename BatchingTask>
144 void RecordProcessedBatchSize(int32 batch_size) {
145 static const std::string batching_task_name = BatchingTask::Name();
// Sampler<0>: no label names are supplied and GetCell() is called without
// arguments below.
146 static auto* cell = tensorflow::monitoring::Sampler<0>::New(
147 {absl::StrCat(
"/tensorflow/serving/", batching_task_name,
148 "/processed_batch_size"),
149 "Tracks the batch size distribution on processing."},
// NOTE(review): presumably scale 1, growth factor 2, 14 buckets -- confirm
// against monitoring::Buckets::Exponential.
151 monitoring::Buckets::Exponential(1, 2, 14));
// The sample is the batch size converted to double.
152 cell->GetCell()->Add(
static_cast<double>(batch_size));