TensorFlow Serving C++ API Documentation
incremental_barrier_test.cc
/* Copyright 2020 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow_serving/batching/incremental_barrier.h"
17 
18 #include <atomic>
19 
20 #include "absl/functional/bind_front.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/core/platform/env.h"
23 #include "tensorflow/core/platform/mutex.h"
24 #include "tensorflow/core/platform/platform.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/test_benchmark.h"
27 #include "tensorflow/core/platform/thread_annotations.h"
28 #include "tensorflow/core/platform/threadpool.h"
29 
30 namespace tensorflow {
31 namespace serving {
32 namespace {
33 
34 // A thread-safe counter class.
35 class Counter {
36  public:
37  void Increment() TF_LOCKS_EXCLUDED(mu_) {
38  mutex_lock l(mu_);
39  ++count_;
40  }
41 
42  int GetCount() TF_LOCKS_EXCLUDED(mu_) {
43  mutex_lock l(mu_);
44  return count_;
45  }
46 
47  private:
48  mutex mu_;
49  int count_ = 0;
50 };
51 
52 TEST(IncrementalBarrierTest, RunInstantlyWhenZeroClosure) {
53  Counter counter;
54  EXPECT_EQ(counter.GetCount(), 0);
55  {
56  IncrementalBarrier::DoneCallback done_callback =
57  absl::bind_front(&Counter::Increment, &counter);
58  IncrementalBarrier barrier(done_callback);
59  EXPECT_EQ(counter.GetCount(), 0);
60  }
61  EXPECT_EQ(counter.GetCount(), 1);
62 }
63 
64 TEST(IncrementalBarrierTest, RunAfterNumClosuresOneNowTwoLater) {
65  Counter counter;
66 
67  IncrementalBarrier::BarrierCallback bc1, bc2;
68  {
69  IncrementalBarrier::DoneCallback done_callback =
70  absl::bind_front(&Counter::Increment, &counter);
71  IncrementalBarrier barrier(done_callback);
72 
73  CHECK_EQ(counter.GetCount(), 0);
74 
75  bc1 = barrier.Inc();
76  bc2 = barrier.Inc();
77 
78  IncrementalBarrier::BarrierCallback bc3 = barrier.Inc();
79  bc3();
80 
81  CHECK_EQ(counter.GetCount(), 0);
82  }
83 
84  CHECK_EQ(counter.GetCount(), 0);
85  bc1();
86  CHECK_EQ(counter.GetCount(), 0);
87  bc2();
88  CHECK_EQ(counter.GetCount(), 1);
89 }
90 
91 TEST(IncrementalBarrierTest, RunAfterNumClosuresConcurrency) {
92  const int num_closure = 100, num_thread = 2;
93  std::atomic<int> schedule_count{0};
94  Counter counter;
95 
96  {
97  IncrementalBarrier::DoneCallback done_callback =
98  absl::bind_front(&Counter::Increment, &counter);
99  IncrementalBarrier barrier(done_callback);
100 
101  CHECK_EQ(counter.GetCount(), 0);
102 
103  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
104  "BarrierClosure", num_thread);
105  for (int i = 0; i < num_closure; ++i) {
106  pool.Schedule([&barrier, &schedule_count]() {
107  schedule_count.fetch_add(1);
108  IncrementalBarrier::BarrierCallback bc = barrier.Inc();
109 
110  Env::Default()->SleepForMicroseconds(100);
111  bc();
112  });
113  }
114 
115  CHECK_EQ(counter.GetCount(), 0);
116  }
117 
118  CHECK_EQ(schedule_count.load(std::memory_order_relaxed), 100);
119  CHECK_EQ(counter.GetCount(), 1);
120 }
121 
#ifdef PLATFORM_GOOGLE
// Measures one Inc() call plus the immediate invocation of the callback it
// returns, on a barrier whose done callback does nothing.
void BM_FunctionInc(benchmark::State& state) {
  IncrementalBarrier barrier([] {});
  for (auto _ : state) {
    IncrementalBarrier::BarrierCallback release = barrier.Inc();
    release();
  }
}

BENCHMARK(BM_FunctionInc);
#endif  // PLATFORM_GOOGLE
132 
133 } // namespace
134 } // namespace serving
135 } // namespace tensorflow