16 #include "tensorflow_serving/batching/incremental_barrier.h"
20 #include "absl/functional/bind_front.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/core/platform/env.h"
23 #include "tensorflow/core/platform/mutex.h"
24 #include "tensorflow/core/platform/platform.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/test_benchmark.h"
27 #include "tensorflow/core/platform/thread_annotations.h"
28 #include "tensorflow/core/platform/threadpool.h"
30 namespace tensorflow {
// Counter::Increment — its body is not visible in this excerpt.
// TF_LOCKS_EXCLUDED(mu_) is a thread-safety annotation: callers must NOT
// already hold mu_ when calling this. The tests bind it (via
// absl::bind_front) as the IncrementalBarrier done callback.
// NOTE(review): presumably acquires mu_ and bumps the stored count — confirm
// against the hidden body.
37 void Increment() TF_LOCKS_EXCLUDED(mu_) {
// Counter::GetCount — returns an int; its body is not visible in this
// excerpt. Like Increment(), it must be called without mu_ held
// (TF_LOCKS_EXCLUDED). The tests compare its result against 0/1 to observe
// whether the barrier's done callback has run.
// NOTE(review): presumably reads the count under mu_ — confirm.
42 int GetCount() TF_LOCKS_EXCLUDED(mu_) {
// Barrier with no Inc() closures: the done callback must not have run right
// after the barrier is constructed, and must have run exactly once by the
// final check.
// NOTE(review): the line between the last two checks is not visible in this
// excerpt; presumably it closes a scope that destroys `barrier` — confirm.
52 TEST(IncrementalBarrierTest, RunInstantlyWhenZeroClosure) {
54 EXPECT_EQ(counter.GetCount(), 0);
// Done callback: bump `counter` through Counter::Increment.
56 IncrementalBarrier::DoneCallback done_callback =
57 absl::bind_front(&Counter::Increment, &counter);
58 IncrementalBarrier barrier(done_callback);
// Constructing the barrier by itself must not run the callback.
59 EXPECT_EQ(counter.GetCount(), 0);
61 EXPECT_EQ(counter.GetCount(), 1);
// Barrier with several outstanding BarrierCallbacks. bc1/bc2 are declared
// before the barrier so they can be invoked later; bc3 is obtained from
// barrier.Inc() locally. The count must stay 0 through every intermediate
// check and be exactly 1 at the final check, i.e. the done callback fires
// once, only after all outstanding work is released.
// NOTE(review): the Inc()/callback invocations between the checks are not
// visible in this excerpt. The test relies on their exact ordering relative
// to the checks, so keep the statements in their current order.
64 TEST(IncrementalBarrierTest, RunAfterNumClosuresOneNowTwoLater) {
67 IncrementalBarrier::BarrierCallback bc1, bc2;
69 IncrementalBarrier::DoneCallback done_callback =
70 absl::bind_front(&Counter::Increment, &counter);
71 IncrementalBarrier barrier(done_callback);
// Nothing has run yet.
73 CHECK_EQ(counter.GetCount(), 0);
78 IncrementalBarrier::BarrierCallback bc3 = barrier.Inc();
81 CHECK_EQ(counter.GetCount(), 0);
84 CHECK_EQ(counter.GetCount(), 0);
86 CHECK_EQ(counter.GetCount(), 0);
// The done callback has fired exactly once.
88 CHECK_EQ(counter.GetCount(), 1);
// Concurrency test: a `num_thread`-thread pool runs `num_closure` closures.
// Each closure bumps `schedule_count`, takes a BarrierCallback from the shared
// barrier via Inc(), and sleeps 100us. The done callback must not have fired
// while the barrier is still in use, and must fire exactly once in the end.
// NOTE(review): the rest of each closure (presumably invoking `bc`) and the
// scope that tears down the pool/barrier are not visible in this excerpt.
91 TEST(IncrementalBarrierTest, RunAfterNumClosuresConcurrency) {
92 const int num_closure = 100, num_thread = 2;
// Number of scheduled closures that actually started running.
93 std::atomic<int> schedule_count{0};
97 IncrementalBarrier::DoneCallback done_callback =
98 absl::bind_front(&Counter::Increment, &counter);
99 IncrementalBarrier barrier(done_callback);
101 CHECK_EQ(counter.GetCount(), 0);
103 tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
104 "BarrierClosure", num_thread);
105 for (
int i = 0; i < num_closure; ++i) {
106 pool.Schedule([&barrier, &schedule_count]() {
107 schedule_count.fetch_add(1);
108 IncrementalBarrier::BarrierCallback bc = barrier.Inc();
110 Env::Default()->SleepForMicroseconds(100);
// The barrier is still alive here, so its done callback must not have run.
115 CHECK_EQ(counter.GetCount(), 0);
// Every scheduled closure must have run. Compare against num_closure instead
// of a duplicated literal so the check stays correct if the count changes.
// NOTE(review): a relaxed load is only sufficient if the pool threads were
// joined before this point (presumably by the hidden scope end) — confirm.
118 CHECK_EQ(schedule_count.load(std::memory_order_relaxed), num_closure);
119 CHECK_EQ(counter.GetCount(), 1);
// The microbenchmark below is compiled only when PLATFORM_GOOGLE is defined.
122 #if defined(PLATFORM_GOOGLE)
// Builds a barrier with a no-op done callback, then iterates over the
// benchmark state.
// NOTE(review): the per-iteration body is not visible in this excerpt;
// presumably it measures the cost of barrier.Inc() (and invoking the returned
// callback) — confirm.
123 void BM_FunctionInc(benchmark::State& state) {
124 IncrementalBarrier barrier([] {});
125 for (
auto _ : state) {
// Register BM_FunctionInc with the benchmark framework.
130 BENCHMARK(BM_FunctionInc);