16 #ifndef TENSORFLOW_SERVING_BATCHING_STREAMING_BATCH_SCHEDULER_H_
17 #define TENSORFLOW_SERVING_BATCHING_STREAMING_BATCH_SCHEDULER_H_
26 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/core/notification.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/platform/cpu_info.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/macros.h"
35 #include "tensorflow/core/platform/mutex.h"
36 #include "tensorflow/core/platform/thread_annotations.h"
37 #include "tensorflow/core/platform/types.h"
38 #include "tensorflow_serving/batching/batch_scheduler_retrier.h"
40 namespace tensorflow {
// Forward declaration only. The type is referred to further down in this file
// as internal::SingleTaskScheduler (it is the type of the batch_closer_
// member and is instantiated in ScheduleCloseOfCurrentOpenBatch()).
43 class SingleTaskScheduler;
48 namespace tensorflow {
// Scheduler class template, parameterized on the task type that is grouped
// into Batch<TaskType> objects. The out-of-line member definitions later in
// this file refer to it as StreamingBatchScheduler<TaskType>.
112 template <
typename TaskType>
// ---- Configuration fields (read elsewhere in this file via options_) ----
//
// Upper bound on the combined size of the tasks placed in one batch.
// Schedule() rejects any single task larger than this value, and Create()
// rejects a value that is not positive.
125 size_t max_batch_size = 1000;
// When greater than zero, Schedule() arms a timed close for a batch at the
// moment it is about to receive its first task; the deadline is
// env->NowMicros() + batch_timeout_micros. With the default of 0 no timed
// close is armed by Schedule().
138 int64_t batch_timeout_micros = 0;
// Name passed to the thread::ThreadPool created in the constructor.
141 string thread_pool_name =
"batch_threads";
// Passed to the thread::ThreadPool constructor and compared against
// num_batches_in_progress_ to decide whether processing capacity remains.
// Create() rejects a value that is not positive.
145 int num_batch_threads = port::MaxParallelism();
// Environment used for the clock (NowMicros()) and handed to both the thread
// pool and the batch-closer when they are constructed.
150 Env* env = Env::Default();
// Forwarded to internal::SingleTaskScheduler when the batch-closer is created.
// Presumably the interval that scheduler waits while it has nothing to run --
// TODO confirm against the SingleTaskScheduler implementation.
154 uint64_t no_tasks_wait_time_micros = 1000;
// Factory. The out-of-line definition validates the options (returning
// InvalidArgument when max_batch_size or num_batch_threads is not positive)
// and then constructs a scheduler around process_batch_callback.
156 static Status Create(
158 std::function<
void(std::unique_ptr<Batch<TaskType>>)>
159 process_batch_callback,
// Tries to add *task to the currently open batch. Returns InvalidArgument if
// the task is larger than max_batch_size and Unavailable if all batch threads
// are busy; in both cases *task has not been moved from. On the success path
// the task is moved out of *task into the batch.
164 Status Schedule(std::unique_ptr<TaskType>* task)
override;
// Always reports zero.
167 size_t NumEnqueuedTasks()
const override {
return 0; }
// Defined out of line: derived from the number of idle batch threads and the
// room left in the open batch.
171 size_t SchedulingCapacity()
const override;
// Largest task size this scheduler accepts, i.e. options_.max_batch_size.
173 size_t max_task_size()
const override {
return options_.max_batch_size; }
// Stores the options and callback and creates the thread pool. The only call
// visible in this file is the `new` expression inside Create().
176 StreamingBatchScheduler(
const Options& options,
177 std::function<
void(std::unique_ptr<Batch<TaskType>>)>
178 process_batch_callback);
// Returns true when adding `task` to `batch` keeps the batch's total size
// at or below options_.max_batch_size.
181 bool TaskFitsInBatch(
const TaskType* task,
182 const Batch<TaskType>* batch)
const;
// Closes the open batch (if there is one), allocates a fresh batch, hands it
// to the thread pool for processing and makes it the new open batch.
// The caller must hold mu_.
186 void StartNewBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Asks batch_closer_ to run a closure at close_time_micros that targets the
// batch which is open at the time of this call (identified through
// open_batch_num_). The caller must hold mu_.
191 void ScheduleCloseOfCurrentOpenBatch(uint64_t close_time_micros)
192 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Immutable copy of the configuration supplied at construction.
194 const Options options_;
// Invoked on a pool thread for every batch; it receives ownership of the
// batch as a std::unique_ptr.
198 std::function<
void(std::unique_ptr<Batch<TaskType>>)> process_batch_callback_;
// Pool on which process_batch_callback_ runs. Reset explicitly in the
// destructor.
201 std::unique_ptr<thread::ThreadPool> batch_threads_;
// Batch currently accepting tasks, or nullptr. Non-owning here: the closure
// scheduled by StartNewBatch() wraps the same pointer in a unique_ptr and
// passes it to the callback.
208 Batch<TaskType>* open_batch_ TF_GUARDED_BY(mu_) =
nullptr;
// Identifier of the currently open batch. ScheduleCloseOfCurrentOpenBatch()
// snapshots it and its closure later compares the snapshot with the then
// current value, so that a stale close request can be recognised.
212 int64_t open_batch_num_ TF_GUARDED_BY(mu_) = 0;
// Number of batches handed to the thread pool whose callback has not yet
// returned. Incremented in StartNewBatch() and decremented by the pool
// closure after the callback finishes.
218 int num_batches_in_progress_ TF_GUARDED_BY(mu_) = 0;
// Created lazily on the first call to ScheduleCloseOfCurrentOpenBatch().
// NOTE(review): the remainder of this declaration (terminator and any
// annotation) is not present on this line -- verify against the full file.
222 std::unique_ptr<internal::SingleTaskScheduler> batch_closer_
225 TF_DISALLOW_COPY_AND_ASSIGN(StreamingBatchScheduler);
// Convenience factory: creates a StreamingBatchScheduler from
// `schedule_options` and `process_batch_callback`, wraps it in a
// BatchSchedulerRetrier configured by `retry_options`, and returns the result
// through `scheduler` as a BatchScheduler<TaskType>. An error from either
// Create() call is propagated to the caller.
229 template <typename TaskType>
230 Status CreateRetryingStreamingBatchScheduler(
231 const typename StreamingBatchScheduler<TaskType>::Options& schedule_options,
232 const typename BatchSchedulerRetrier<TaskType>::Options& retry_options,
233 std::function<
void(std::unique_ptr<Batch<TaskType>>)>
234 process_batch_callback,
235 std::unique_ptr<BatchScheduler<TaskType>>* scheduler);
// NOTE(review): the lines from here to the next template definition appear to
// be members of internal::SingleTaskScheduler -- the parameter below and the
// Schedule() signature match how StreamingBatchScheduler constructs and calls
// its batch_closer_. The enclosing class header is not part of this excerpt;
// confirm against the full file.
//
// Last constructor parameter; StreamingBatchScheduler passes
// Options::no_tasks_wait_time_micros for it.
249 uint64_t no_tasks_wait_time_micros);
// Requests that `closure` be run at `time_micros`. StreamingBatchScheduler
// passes a time computed from Env::NowMicros(), so the value is presumably on
// that clock -- TODO confirm.
259 void Schedule(uint64_t time_micros, std::function<
void()> closure);
// Fields describing one scheduled task (used as the `Task` type below):
// when to run, and what to run.
274 uint64_t time_micros;
275 std::function<void()> closure;
// Pending task, if any. Presumably written by Schedule() and consumed by the
// worker thread -- TODO confirm.
279 absl::optional<Task> updated_task_ TF_GUARDED_BY(mu_);
// NOTE(review): unlike updated_task_ this field carries no TF_GUARDED_BY
// annotation; verify which thread(s) access it.
283 uint64_t last_task_time_ = 0;
// Name for the worker thread (StreamingBatchScheduler supplies "batch_closer").
289 const string thread_name_;
// Worker thread owned by this object.
292 std::unique_ptr<Thread> thread_;
// Stored copy of the constructor's no_tasks_wait_time_micros argument
// (presumably; the constructor body is not shown here).
295 const uint64_t no_tasks_wait_time_micros_;
// Out-of-line definition of the static factory StreamingBatchScheduler::Create.
// Validates `options` and, if they are acceptable, allocates a new scheduler
// bound to `process_batch_callback`.
// Returns InvalidArgument when max_batch_size or num_batch_threads is not
// positive.
302 template <
typename TaskType>
304 const Options& options,
305 std::function<
void(std::unique_ptr<Batch<TaskType>>)>
306 process_batch_callback,
// max_batch_size is declared as size_t, so this test can only be true when
// the value is exactly zero.
308 if (options.max_batch_size <= 0) {
309 return errors::InvalidArgument(
"max_batch_size must be positive; was ",
310 options.max_batch_size);
// num_batch_threads is a signed int, so zero and negative values are both
// rejected here.
312 if (options.num_batch_threads <= 0) {
313 return errors::InvalidArgument(
"num_batch_threads must be positive; was ",
314 options.num_batch_threads);
// Options are valid: build the scheduler.
317 new StreamingBatchScheduler<TaskType>(options, process_batch_callback));
// Destructor: closes the batch that is still open (if any) and then tears
// down the thread pool.
321 template <
typename TaskType>
322 StreamingBatchScheduler<TaskType>::~StreamingBatchScheduler() {
// Close the open batch and drop our (non-owning) pointer to it.
325 if (open_batch_ !=
nullptr) {
326 open_batch_->Close();
327 open_batch_ =
nullptr;
// Destroy the thread pool explicitly at this point rather than leaving it to
// implicit member destruction -- presumably so that batch processing still
// in flight completes while the other members are alive; confirm against
// thread::ThreadPool's destructor contract.
333 batch_threads_.reset(
nullptr);
// Adds *task to the open batch, starting or closing batches as required.
//
// Returns InvalidArgument if the task alone exceeds max_batch_size, and
// Unavailable if more batches are in progress than there are batch threads.
// *task is moved from only when it is actually added to a batch; on the error
// paths shown here the caller keeps ownership.
336 template <
typename TaskType>
337 Status StreamingBatchScheduler<TaskType>::Schedule(
338 std::unique_ptr<TaskType>* task) {
// A task bigger than a whole batch can never be scheduled.
339 if ((*task)->size() > options_.max_batch_size) {
340 return errors::InvalidArgument(
"Task size ", (*task)->size(),
341 " is larger than maximum batch size ",
342 options_.max_batch_size);
// A fresh batch is needed when none is open or when the task would push the
// open batch past max_batch_size.
348 if (open_batch_ ==
nullptr || !TaskFitsInBatch(task->get(), open_batch_)) {
// More batches in progress than threads: the DCHECK documents the expectation
// that the open batch is then an empty one, so the task is refused instead of
// being added to it.
355 if (num_batches_in_progress_ > options_.num_batch_threads) {
356 DCHECK(open_batch_->empty());
357 return errors::Unavailable(
358 "This task would start a fresh batch, but all batch threads are "
359 "busy, so at present there is no processing capacity available for "
// About to add the first task to this batch and a timeout is configured:
// arm the timed close, measured from now.
365 if (options_.batch_timeout_micros > 0 && open_batch_->empty()) {
366 const uint64_t batch_deadline =
367 options_.env->NowMicros() + options_.batch_timeout_micros;
368 ScheduleCloseOfCurrentOpenBatch(batch_deadline);
// Ownership of the task moves into the batch here.
371 open_batch_->AddTask(std::move(*task));
// The batch has reached exactly max_batch_size, so no further task can fit.
374 if (open_batch_->size() == options_.max_batch_size) {
// Reports how much additional task size can currently be accepted: one full
// batch per idle batch thread plus whatever room is left in the open batch.
382 template <
typename TaskType>
383 size_t StreamingBatchScheduler<TaskType>::SchedulingCapacity()
const {
// Same "all threads busy" condition that makes Schedule() return Unavailable.
385 if (num_batches_in_progress_ > options_.num_batch_threads) {
// Threads that are not currently processing a batch.
388 const int num_idle_threads =
389 options_.num_batch_threads - num_batches_in_progress_;
// Remaining room in the open batch; zero when no batch is open.
390 const int open_batch_capacity =
391 open_batch_ ==
nullptr ? 0
392 : options_.max_batch_size - open_batch_->size();
393 return (num_idle_threads * options_.max_batch_size) + open_batch_capacity;
// Constructor: keeps a copy of the callback and creates the thread pool on
// which batches are processed. The body is empty; all work happens in the
// member-initializer list.
396 template <
typename TaskType>
397 StreamingBatchScheduler<TaskType>::StreamingBatchScheduler(
398 const Options& options,
399 std::function<
void(std::unique_ptr<Batch<TaskType>>)>
400 process_batch_callback)
402 process_batch_callback_(process_batch_callback),
// The pool is configured from the stored options_ member (not the `options`
// parameter), so it relies on options_ having been initialized first.
403 batch_threads_(new thread::ThreadPool(options_.env,
404 options_.thread_pool_name,
405 options_.num_batch_threads)) {}
// Returns true if `task` can be added to `batch` without the batch's total
// size exceeding options_.max_batch_size. Neither argument is modified.
407 template <
typename TaskType>
408 bool StreamingBatchScheduler<TaskType>::TaskFitsInBatch(
409 const TaskType* task,
const Batch<TaskType>* batch)
const {
410 return batch->size() + task->size() <= options_.max_batch_size;
// Retires the current open batch (if any) and opens a new one, which is
// immediately handed to the thread pool. The declaration requires the caller
// to hold mu_.
413 template <
typename TaskType>
414 void StreamingBatchScheduler<TaskType>::StartNewBatch() {
// Close the previous batch so it accepts no further tasks, and forget it.
415 if (open_batch_ !=
nullptr) {
416 open_batch_->Close();
417 open_batch_ =
nullptr;
// Allocated with a raw `new` because ownership is handed to the callback
// inside the closure below, while open_batch_ keeps a non-owning pointer.
420 Batch<TaskType>* new_open_batch =
new Batch<TaskType>;
// Count the batch as in progress from the moment it is handed to the pool.
421 ++num_batches_in_progress_;
// Pool closure: run the user callback with ownership of the batch ...
422 batch_threads_->Schedule([
this, new_open_batch] {
423 this->process_batch_callback_(
424 std::unique_ptr<Batch<TaskType>>(new_open_batch));
// ... then, under mu_, record that this batch is no longer in progress.
426 mutex_lock l(this->mu_);
427 --this->num_batches_in_progress_;
// New tasks now go into the freshly created batch.
430 open_batch_ = new_open_batch;
// Registers a closure with batch_closer_ to run at close_time_micros. The
// closure targets the batch that is open right now. The declaration requires
// the caller to hold mu_.
434 template <
typename TaskType>
435 void StreamingBatchScheduler<TaskType>::ScheduleCloseOfCurrentOpenBatch(
436 uint64_t close_time_micros) {
// Create the closer lazily, on the first request for a timed close.
437 if (batch_closer_ ==
nullptr) {
438 batch_closer_.reset(
new internal::SingleTaskScheduler(
439 options_.env,
"batch_closer", options_.no_tasks_wait_time_micros));
// Snapshot the identifier of the batch that is open now; it is captured by
// value so the closure can tell later whether that batch is still the open
// one.
442 const int64_t batch_num_to_close = open_batch_num_;
443 batch_closer_->Schedule(close_time_micros, [
this, batch_num_to_close] {
// Runs later from the closer; takes mu_ itself before touching guarded state.
445 mutex_lock l(this->mu_);
// Act only if the targeted batch is still the open one.
446 if (open_batch_num_ == batch_num_to_close) {
// Builds a StreamingBatchScheduler and wraps it in a BatchSchedulerRetrier.
//
//   schedule_options        options for the inner StreamingBatchScheduler
//   retry_options           options for the BatchSchedulerRetrier wrapper
//   process_batch_callback  invoked by the inner scheduler for each batch
//   scheduler               out-parameter receiving the wrapped scheduler
//
// If either Create() call fails, its Status is returned immediately and
// *scheduler is not assigned.
453 template <
typename TaskType>
454 Status CreateRetryingStreamingBatchScheduler(
455 const typename StreamingBatchScheduler<TaskType>::Options& schedule_options,
456 const typename BatchSchedulerRetrier<TaskType>::Options& retry_options,
457 std::function<
void(std::unique_ptr<Batch<TaskType>>)>
458 process_batch_callback,
459 std::unique_ptr<BatchScheduler<TaskType>>* scheduler) {
// Step 1: the inner streaming scheduler.
460 std::unique_ptr<StreamingBatchScheduler<TaskType>> streaming_scheduler;
461 TF_RETURN_IF_ERROR(StreamingBatchScheduler<TaskType>::Create(
462 schedule_options, process_batch_callback, &streaming_scheduler));
// Step 2: the retrier takes ownership of the inner scheduler.
463 std::unique_ptr<BatchSchedulerRetrier<TaskType>> retrier;
464 TF_RETURN_IF_ERROR(BatchSchedulerRetrier<TaskType>::Create(
465 retry_options, std::move(streaming_scheduler), &retrier));
// Hand the retrier back through the base-class pointer.
466 *scheduler = std::move(retrier);