#include "tensorflow_serving/batching/tfrt_saved_model_with_batching.h"

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow_serving/batching/batching_util.h"
#include "tensorflow_serving/batching/incremental_barrier.h"
namespace tensorflow {
namespace serving {

using tfrt::FunctionMetadata;
using tfrt::SavedModel;
auto *queuing_latency = monitoring::Sampler<0>::New(
    {"/tensorflow/serving/saved_model_with_batching/queuing_latency",
     "Distribution of wall time spent (in microseconds) in queuing"},
    monitoring::Buckets::Exponential(100, 1.2, 52));
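
// SavedModelWithBatching wraps another tfrt::SavedModel and transparently
// batches Run() requests: each registered function gets its own batch
// scheduler, inputs from concurrent Run() calls are merged into one batched
// invocation of the wrapped model, and the combined outputs are split back
// to the individual callers.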
class SavedModelWithBatching : public tfrt::SavedModel {
 public:
  static Status Create(const SavedModelBatchingOptions &options,
                       const std::vector<FuncNameWithBatchingSchedulerCreator>
                           &func_name_with_batching_scheduler_creator,
                       std::unique_ptr<SavedModel> saved_model,
                       std::unique_ptr<SavedModel> *result);
  SavedModelWithBatching(const SavedModelBatchingOptions &options,
                         std::unique_ptr<SavedModel> saved_model);
  const tensorflow::MetaGraphDef &GetMetaGraphDef() const override {
    return wrapped_->GetMetaGraphDef();
  }
  std::vector<std::string> GetFunctionNames() const override {
    return wrapped_->GetFunctionNames();
  }
  absl::optional<FunctionMetadata> GetFunctionMetadata(
      absl::string_view func_name) const override {
    return wrapped_->GetFunctionMetadata(func_name);
  }
  // Schedules `inputs` onto the batch scheduler registered for `func_name`
  // and blocks until the batched execution completes. Falls back to an
  // unbatched run if no scheduler is registered for `func_name`.
  Status Run(const tfrt::SavedModel::RunOptions &run_options,
             absl::string_view func_name, absl::Span<const Tensor> inputs,
             std::vector<Tensor> *outputs) override;
  Status RunMultipleSignatures(
      const RunOptions &run_options, absl::Span<const std::string> names,
      absl::Span<const std::vector<tensorflow::Tensor>> multi_inputs,
      std::vector<std::vector<tensorflow::Tensor>> *multi_outputs) override {
    // Batching is not applied on this path; the call is forwarded directly.
    return wrapped_->RunMultipleSignatures(run_options, names, multi_inputs,
                                           multi_outputs);
  }
  Status RunByTensorNames(
      const RunOptions &run_options,
      absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
      absl::Span<const std::string> output_tensor_names,
      absl::Span<const std::string> target_node_names,
      std::vector<tensorflow::Tensor> *outputs) {
    // Batching is not applied on this path; the call is forwarded directly.
    return wrapped_->RunByTensorNames(run_options, inputs, output_tensor_names,
                                      target_node_names, outputs);
  }
 private:
  // Invoked by the batch scheduler to process one closed batch for
  // `func_name`.
  void ProcessBatch(absl::string_view func_name,
                    std::unique_ptr<Batch<SavedModelBatchingTask>> batch);

  // Merges (and optionally pads) the input tensors of all tasks in `batch`
  // into one set of batched input tensors.
  Status BatchInputTensors(absl::string_view func_name,
                           const Batch<SavedModelBatchingTask> &batch,
                           std::vector<Tensor> *batch_inputs);

  // Splits the batched output tensors back into per-task outputs.
  Status SplitOutputTensors(std::vector<Tensor> combined_outputs,
                            Batch<SavedModelBatchingTask> *batch);

  const SavedModelBatchingOptions options_;

  // The underlying SavedModel that executes the batched requests.
  std::unique_ptr<SavedModel> wrapped_;

  // One batch scheduler per batched function name.
  absl::flat_hash_map<std::string,
                      std::unique_ptr<BatchScheduler<SavedModelBatchingTask>>>
      batch_schedulers_;

  TF_DISALLOW_COPY_AND_ASSIGN(SavedModelWithBatching);
};
SavedModelWithBatching::SavedModelWithBatching(
    const SavedModelBatchingOptions &options,
    std::unique_ptr<SavedModel> saved_model)
    : tfrt::SavedModel(&saved_model->runtime()),
      options_(options),
      wrapped_(std::move(saved_model)) {}
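
// Validates the inputs and creates one batch scheduler per entry in
// `func_name_with_batching_scheduler_creators`. Each scheduler is wired to
// ProcessBatch() for its function name; functions without a scheduler entry
// are served unbatched by the wrapped model.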
Status SavedModelWithBatching::Create(
    const SavedModelBatchingOptions &options,
    const std::vector<FuncNameWithBatchingSchedulerCreator>
        &func_name_with_batching_scheduler_creators,
    std::unique_ptr<SavedModel> saved_model,
    std::unique_ptr<SavedModel> *result) {
  if (saved_model == nullptr) {
    return errors::FailedPrecondition("saved_model must not be null.");
  }

  SavedModel *raw_saved_model = saved_model.get();
  std::unique_ptr<SavedModelWithBatching> saved_model_with_batching =
      absl::make_unique<SavedModelWithBatching>(options,
                                                std::move(saved_model));
  SavedModelWithBatching *raw_saved_model_with_batching =
      saved_model_with_batching.get();

  for (const auto &entry : func_name_with_batching_scheduler_creators) {
    if (!raw_saved_model->GetFunctionMetadata(entry.func_name)) {
      LOG(WARNING) << "Function " << entry.func_name
                   << " is not found in the model.";
    }
    auto insert_result = saved_model_with_batching->batch_schedulers_.emplace(
        std::string(entry.func_name), nullptr);
    if (!insert_result.second) {
      return errors::FailedPrecondition(absl::StrCat(
          "Specified multiple batch schedulers for function ",
          entry.func_name));
    }
    const std::string &func_name = insert_result.first->first;
    TF_RETURN_IF_ERROR(entry.scheduler_creator(
        [func_name, raw_saved_model_with_batching](
            std::unique_ptr<Batch<SavedModelBatchingTask>> batch) {
          raw_saved_model_with_batching->ProcessBatch(func_name,
                                                      std::move(batch));
        },
        &insert_result.first->second));
    if (insert_result.first->second == nullptr) {
      return errors::FailedPrecondition(absl::StrCat(
          "Failed to create batch scheduler for function ", entry.func_name));
    }
  }

  *result = std::move(saved_model_with_batching);
  return Status();
}
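
// If a batch scheduler is registered for `func_name`, wraps the request into
// a SavedModelBatchingTask, schedules it, and blocks until the batch that
// contains it has been processed. Requests for unregistered functions bypass
// batching entirely.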
Status SavedModelWithBatching::Run(
    const tfrt::SavedModel::RunOptions &run_options,
    absl::string_view func_name, absl::Span<const Tensor> inputs,
    std::vector<Tensor> *outputs) {
  if (outputs == nullptr) {
    return errors::FailedPrecondition("outputs must not be null");
  }
  auto it = batch_schedulers_.find(func_name);
  if (it == batch_schedulers_.end()) {
    // No scheduler is registered for this function; log (rate-limited) and
    // run the request unbatched.
    static uint64_t last_log_message_secs = 0;
    uint64_t now_secs = EnvTime::NowSeconds();
    if (now_secs - last_log_message_secs >= 120) {
      LOG(WARNING) << "Request doesn't match any declared function. Bypassing "
                      "batcher. Request function is: "
                   << func_name;
      last_log_message_secs = now_secs;
    }
    return wrapped_->Run(run_options, func_name, inputs, outputs);
  }

  Notification done;
  Status status;
  auto task = absl::make_unique<SavedModelBatchingTask>();
  TF_RETURN_IF_ERROR(ComputeTensorBatchSize(
      inputs, &task->zeroth_dim_size,
      [](const Tensor &tensor) { return tensor.dims(); },
      [](const Tensor &tensor, size_t dim) { return tensor.dim_size(dim); }));
  RecordInputBatchSize<SavedModelBatchingTask>(task->zeroth_dim_size);

  task->host_context = GetHostContext();
  task->tfrt_inputs = inputs;
  task->tfrt_outputs = outputs;
  task->done = &done;
  task->status = &status;
  task->run_options = run_options;
  task->enqueue_time_micros = EnvTime::NowMicros();

  TF_RETURN_IF_ERROR(it->second->Schedule(&task));
  done.WaitForNotification();
  return status;
}
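
// For each input tensor position, computes the maximum size of every
// dimension across all tasks in `batch`. Used to pad variable-length inputs
// so that tensors from different tasks can be concatenated.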
std::vector<absl::InlinedVector<int, 4>> CalculateMaxDimSizes(
    const Batch<SavedModelBatchingTask> &batch) {
  std::vector<absl::InlinedVector<int, 4>> max_dim_sizes;
  for (int batch_idx = 0; batch_idx < batch.num_tasks(); ++batch_idx) {
    const auto inputs = batch.task(batch_idx).tfrt_inputs;
    for (int tensor_idx = 0; tensor_idx < inputs.size(); ++tensor_idx) {
      const Tensor &tensor = inputs[tensor_idx];
      const TensorShape &shape = tensor.shape();
      const int rank = shape.dims();

      absl::InlinedVector<int, 4> dims;
      for (auto dim : shape) {
        dims.push_back(dim.size);
      }

      if (batch_idx == 0) {
        max_dim_sizes.push_back(std::move(dims));
      } else {
        for (int rank_idx = 0; rank_idx < rank; ++rank_idx) {
          int &cur_max_size = max_dim_sizes[tensor_idx][rank_idx];
          cur_max_size = std::max(cur_max_size, dims[rank_idx]);
        }
      }
    }
  }
  return max_dim_sizes;
}
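
// Concatenates the input tensors of all tasks in `batch` along dimension 0
// into `batch_inputs`, optionally padding variable-length dimensions first,
// and padding the batch up to the nearest allowed batch size by replicating
// the last row.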
Status SavedModelWithBatching::BatchInputTensors(
    absl::string_view func_name, const Batch<SavedModelBatchingTask> &batch,
    std::vector<Tensor> *batch_inputs) {
  if (batch.num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch.num_tasks());
  }
  const int original_batch_size = batch.size();
  const int target_batch_size = RoundToLowestAllowedBatchSize(
      options_.allowed_batch_sizes, original_batch_size);
  const int padding_size = target_batch_size - original_batch_size;
  RecordPaddingSize<SavedModelBatchingTask>(padding_size, target_batch_size);
  RecordProcessedBatchSize<SavedModelBatchingTask>(target_batch_size);

  std::vector<absl::InlinedVector<int, 4>> max_dim_sizes;
  if (options_.pad_variable_length_inputs) {
    max_dim_sizes = CalculateMaxDimSizes(batch);
  }

  // tensors_to_merge[i] holds, for input position i, one tensor per task
  // (plus padding rows), to be concatenated along dimension 0.
  std::vector<std::vector<Tensor>> tensors_to_merge(
      batch.task(0).tfrt_inputs.size(), std::vector<Tensor>());
  for (int batch_idx = 0; batch_idx < batch.num_tasks(); ++batch_idx) {
    auto inputs = batch.task(batch_idx).tfrt_inputs;
    for (int tensor_idx = 0; tensor_idx < inputs.size(); ++tensor_idx) {
      Tensor tensor = inputs[tensor_idx];
      std::vector<Tensor> &tensor_vec = tensors_to_merge[tensor_idx];

      Tensor optionally_padded_tensor;
      if (options_.pad_variable_length_inputs) {
        TF_RETURN_IF_ERROR(AddPadding(tensor, max_dim_sizes[tensor_idx],
                                      &optionally_padded_tensor));
      } else {
        optionally_padded_tensor = tensor;
        if (batch_idx != 0) {
          TensorShape reference_shape = tensors_to_merge[tensor_idx][0].shape();
          if (!AreShapesEqualExceptZeroDim(tensor.shape(), reference_shape)) {
            return errors::FailedPrecondition(
                "Tensors in a single batch have different shapes other than"
                " first dimension and padding is turned off.");
          }
        }
      }
      tensor_vec.push_back(std::move(optionally_padded_tensor));

      if (batch_idx == batch.num_tasks() - 1 && padding_size > 0) {
        // Pad the batch up to the target size by repeating the last row.
        const Tensor padding_tensor = tensor_vec.back().Slice(0, 1);
        for (int i = 0; i < padding_size; ++i) {
          tensor_vec.push_back(padding_tensor);
        }
      }
    }
  }

  for (const auto &tensors : tensors_to_merge) {
    Tensor concated;
    TF_RETURN_IF_ERROR(tensor::Concat(tensors, &concated));
    batch_inputs->push_back(concated);
  }
  return Status();
}
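
// Runs one closed batch through the wrapped model. On exit every task in the
// batch has been completed: split tasks have their partial_status updated and
// their done_callback invoked, while regular tasks have their status set and
// their waiting caller notified.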
void SavedModelWithBatching::ProcessBatch(
    absl::string_view func_name,
    std::unique_ptr<Batch<SavedModelBatchingTask>> batch) {
  batch->WaitUntilClosed();

  if (batch->empty()) return;
  Status status = Status();
  // Regardless of the outcome, propagate `status` to every task in the batch
  // and signal completion on the way out.
  auto cleanup = gtl::MakeCleanup([&status, &batch] {
    for (int batch_idx = 0; batch_idx < batch->num_tasks(); ++batch_idx) {
      SavedModelBatchingTask *task = batch->mutable_task(batch_idx);
      if (task->partial_status != nullptr) {
        task->partial_status->Update(status);
        task->done_callback();
      } else {
        *(task->status) = status;
        task->done->Notify();
      }
    }
  });

  const uint64_t dequeue_time_micros = EnvTime::NowMicros();

  bool all_tasks_timeout_exceeded = true;
  absl::optional<std::chrono::system_clock::time_point> batch_deadline;
  for (int batch_idx = 0; batch_idx < batch->num_tasks(); ++batch_idx) {
    const SavedModelBatchingTask &task = batch->task(batch_idx);
    if (!task.run_options.deadline.has_value() ||
        absl::ToChronoTime(absl::Now()) < task.run_options.deadline.value()) {
      all_tasks_timeout_exceeded = false;
      if (task.run_options.deadline.has_value() &&
          (!batch_deadline.has_value() ||
           batch_deadline.value() < task.run_options.deadline.value())) {
        batch_deadline = task.run_options.deadline;
      }
    }
    queuing_latency->GetCell()->Add(dequeue_time_micros -
                                    task.enqueue_time_micros);
  }

  if (all_tasks_timeout_exceeded) {
    status = Status(
        static_cast<absl::StatusCode>(absl::StatusCode::kResourceExhausted),
        "Run() timeout exceeded while waiting in batching queue");
    return;
  }

  tfrt::SavedModel::RunOptions batch_run_options;
  batch_run_options.deadline = batch_deadline;
  std::vector<Tensor> batch_inputs;
  status = BatchInputTensors(func_name, *batch, &batch_inputs);
  if (!status.ok()) return;

  std::vector<Tensor> combined_outputs;
  status = wrapped_->Run(batch_run_options, func_name, batch_inputs,
                         &combined_outputs);
  if (!status.ok()) return;
  status = SplitOutputTensors(std::move(combined_outputs), batch.get());
}
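
// Splits each combined output tensor along dimension 0 according to the
// per-task sizes (padding rows, if any, are dropped) and appends the
// resulting slices to each task's output vector.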
Status SavedModelWithBatching::SplitOutputTensors(
    std::vector<Tensor> combined_outputs,
    Batch<SavedModelBatchingTask> *batch) {
  std::vector<int64_t> split_batch_sizes;
  split_batch_sizes.reserve(batch->num_tasks());
  for (int batch_idx = 0; batch_idx < batch->num_tasks(); ++batch_idx) {
    split_batch_sizes.push_back(batch->task(batch_idx).size());
  }
  const int64_t no_padded_batch_size = batch->size();
  const int64_t padded_batch_size = RoundToLowestAllowedBatchSize(
      options_.allowed_batch_sizes, no_padded_batch_size);

  const int64_t padding_size = padded_batch_size - no_padded_batch_size;
  if (padding_size > 0) {
    split_batch_sizes.push_back(padding_size);
  }

  for (const auto &combined_tensor : combined_outputs) {
    std::vector<Tensor> split_tensors;
    TF_RETURN_IF_ERROR(
        tensor::Split(combined_tensor, split_batch_sizes, &split_tensors));

    for (int batch_idx = 0; batch_idx < batch->num_tasks(); ++batch_idx) {
      SavedModelBatchingTask *task = batch->mutable_task(batch_idx);
      task->tfrt_outputs->push_back(split_tensors.at(batch_idx));
    }
  }
  return Status();
}
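
// Public factory: wraps `saved_model` with batching support using the given
// per-function batch scheduler creators.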
Status CreateSavedModelWithBatching(
    const SavedModelBatchingOptions &options,
    const std::vector<FuncNameWithBatchingSchedulerCreator>
        &func_name_with_batching_scheduler_creator,
    std::unique_ptr<tfrt::SavedModel> saved_model,
    std::unique_ptr<tfrt::SavedModel> *saved_model_with_batching) {
  return SavedModelWithBatching::Create(
      options, func_name_with_batching_scheduler_creator,
      std::move(saved_model), saved_model_with_batching);
}
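
// Splits one oversized input task into several smaller tasks: the first split
// fills the currently open batch (`open_batch_remaining_slot` rows) and the
// remainder is cut into chunks of at most `max_batch_size`. Each split writes
// its outputs into its own buffer; once all splits have completed (tracked by
// an IncrementalBarrier), the per-split outputs are concatenated back in
// order and handed to the original caller.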
Status SplitSavedModelInputTask(
    std::unique_ptr<SavedModelBatchingTask> *input_task_ptr,
    int open_batch_remaining_slot, int max_batch_size,
    std::vector<std::unique_ptr<SavedModelBatchingTask>> *output_tasks) {
  SavedModelBatchingTask *input_task = input_task_ptr->get();

  // Each split task gets its own output vector; the shared status accumulates
  // errors from all splits.
  auto split_output =
      std::make_shared<std::vector<std::unique_ptr<std::vector<Tensor>>>>();
  auto partial_status = std::make_shared<ThreadSafeStatus>();

  auto split_task_done_callback = [split_output, partial_status,
                                   status = input_task->status,
                                   output = input_task->tfrt_outputs,
                                   done_notification = input_task->done]() {
    auto cleanup = gtl::MakeCleanup(
        [done_notification]() { done_notification->Notify(); });

    if (!partial_status->status().ok()) {
      *status = partial_status->status();
      return;
    }

    int output_size = split_output->size();
    int tensor_size = (*split_output)[0]->size();
    for (int tensor_idx = 0; tensor_idx < tensor_size; ++tensor_idx) {
      Tensor output_tensor;
      std::vector<Tensor> to_concatenate;
      to_concatenate.reserve(output_size);
      for (int output_idx = 0; output_idx < output_size; ++output_idx) {
        to_concatenate.push_back(
            std::move((*(*split_output)[output_idx])[tensor_idx]));
      }
      const auto concat_status = tensor::Concat(to_concatenate, &output_tensor);
      if (!concat_status.ok()) {
        *status = concat_status;
        return;
      }
      output->push_back(output_tensor);
    }
  };

  IncrementalBarrier barrier(std::move(split_task_done_callback));
  std::vector<int64_t> output_task_sizes;

  if (open_batch_remaining_slot > 0) {
    output_task_sizes.push_back(open_batch_remaining_slot);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }

  for (int left_task_size = input_task->size() - open_batch_remaining_slot;
       left_task_size > 0; left_task_size -= max_batch_size) {
    int next_task_size = std::min(left_task_size, max_batch_size);
    output_task_sizes.push_back(next_task_size);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }

  const int output_task_num = output_task_sizes.size();

  output_tasks->reserve(output_task_num);
  for (int i = 0; i < output_task_num; ++i) {
    auto task = absl::make_unique<SavedModelBatchingTask>();
    task->zeroth_dim_size = output_task_sizes[i];
    task->run_options = input_task->run_options;
    task->tfrt_outputs = (*split_output)[i].get();
    task->done_callback = barrier.Inc();
    task->partial_status = partial_status.get();
    output_tasks->push_back(std::move(task));
  }

  for (const Tensor &input : input_task->tfrt_inputs) {
    std::vector<Tensor> split_tensors;
    TF_RETURN_IF_ERROR(tensor::Split(input, output_task_sizes, &split_tensors));
    for (int output_idx = 0; output_idx < output_task_num; ++output_idx) {
      auto &output_task = (*output_tasks)[output_idx];
      output_task->tfrt_partial_inputs.push_back(split_tensors[output_idx]);
    }
  }

  for (auto &task : *output_tasks) {
    task->tfrt_inputs = task->tfrt_partial_inputs;