16 #include "tensorflow_serving/batching/streaming_batch_scheduler.h"
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/macros.h"
26 #include "tensorflow/core/protobuf/error_codes.pb.h"
28 using ::testing::ElementsAre;
29 using ::testing::IsEmpty;
30 using ::testing::UnorderedElementsAre;
32 namespace tensorflow {
// Minimal BatchTask used by these tests. Its only state is the size passed to
// the constructor, which size() reports back.
// NOTE(review): the declaration of size_ (and any access specifiers / closing
// brace) are not visible in this excerpt — confirm against the full file.
36 class FakeTask :
public BatchTask {
38 explicit FakeTask(
size_t size) : size_(size) {}
40 ~FakeTask()
override =
default;
// Returns the size given at construction.
42 size_t size()
const override {
return size_; }
// Copying and assignment are disabled.
47 TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
// Test helper: wraps a new FakeTask of size `task_size` in a unique_ptr and
// passes it to `scheduler->Schedule()`.
// The CHECK_EQ enforces the contract exercised here: `task` must be null
// (moved out by the scheduler) exactly when Schedule() returns OK, and must
// still hold the task when Schedule() fails.
// NOTE(review): the return statement is not visible in this excerpt;
// presumably the function returns `status` — confirm.
52 Status ScheduleTask(
size_t task_size, BatchScheduler<FakeTask>* scheduler) {
53 std::unique_ptr<FakeTask> task(
new FakeTask(task_size));
54 Status status = scheduler->Schedule(&task);
56 CHECK_EQ(status.ok(), task ==
nullptr);
// Schedules two tasks (sizes 3 and 5) that together fit within
// max_batch_size (10). The callback expects to see both of them in one batch,
// in scheduling order.
60 TEST(StreamingBatchSchedulerTest, Basic) {
61 bool callback_called =
false;
62 auto callback = [&callback_called](std::unique_ptr<Batch<FakeTask>> batch) {
63 callback_called =
true;
// Wait until the scheduler closes the batch before inspecting its contents.
64 batch->WaitUntilClosed();
65 ASSERT_EQ(2, batch->num_tasks());
66 EXPECT_EQ(3, batch->task(0).size());
67 EXPECT_EQ(5, batch->task(1).size());
70 StreamingBatchScheduler<FakeTask>::Options options;
71 options.max_batch_size = 10;
72 options.batch_timeout_micros = 100 * 1000;
73 options.num_batch_threads = 1;
74 std::unique_ptr<StreamingBatchScheduler<FakeTask>> scheduler;
75 TF_ASSERT_OK(StreamingBatchScheduler<FakeTask>::Create(options, callback,
77 TF_ASSERT_OK(ScheduleTask(3, scheduler.get()));
78 TF_ASSERT_OK(ScheduleTask(5, scheduler.get()));
// NOTE(review): the line(s) between scheduling and this check are not visible
// here. Presumably the scheduler is destroyed first, so the callback has
// finished before callback_called is read — confirm.
80 EXPECT_TRUE(callback_called);
// Checks that batches never exceed max_batch_size (10).
// - Tasks 3 and 5 share a batch.
// - The next task (3) would overflow that batch (3 + 5 + 3 = 11), so the
//   expectation places it in a second batch that fills exactly with 3, 1 and 6.
// - The expected result also contains one empty batch (IsEmpty()); why it
//   arises is not visible in this excerpt.
// NOTE(review): the opening of the callback lambda (line 87) is not visible.
// callback_data is appended from batch threads (num_batch_threads = 2), so
// presumably a mutex captured there guards it — confirm.
83 TEST(StreamingBatchSchedulerTest, ObeyBatchSizeConstraint) {
// Task sizes of every batch the callback receives, one inner vector per batch.
86 std::vector<std::vector<size_t>> callback_data;
88 &callback_data](std::unique_ptr<Batch<FakeTask>> batch) {
89 batch->WaitUntilClosed();
// Record the size of each task in this (now closed) batch.
90 std::vector<size_t> batch_data;
91 for (
int i = 0; i < batch->num_tasks(); ++i) {
92 batch_data.push_back(batch->mutable_task(i)->size());
96 callback_data.push_back(batch_data);
102 StreamingBatchScheduler<FakeTask>::Options options;
103 options.max_batch_size = 10;
104 options.batch_timeout_micros = 100 * 1000;
105 options.num_batch_threads = 2;
106 std::unique_ptr<StreamingBatchScheduler<FakeTask>> scheduler;
107 TF_ASSERT_OK(StreamingBatchScheduler<FakeTask>::Create(options, callback,
111 TF_ASSERT_OK(ScheduleTask(3, scheduler.get()));
112 TF_ASSERT_OK(ScheduleTask(5, scheduler.get()));
// 3 + 5 + 3 would exceed max_batch_size (10), so this task is expected to land
// in a new batch together with the following 1 and 6 (total 10).
115 TF_ASSERT_OK(ScheduleTask(3 , scheduler.get()));
116 TF_ASSERT_OK(ScheduleTask(1, scheduler.get()));
117 TF_ASSERT_OK(ScheduleTask(6, scheduler.get()));
// Unordered match: with two batch threads, the order in which callbacks append
// to callback_data is presumably not fixed.
125 UnorderedElementsAre(ElementsAre(3, 5), ElementsAre(3, 1, 6), IsEmpty()));
// Uses a FakeClockEnv to check batch_timeout_micros (10us) precisely:
//  * a lone size-1 task is not processed after 9us of fake time, but is at 10us;
//  * a size-2 batch is closed with no clock advance as soon as a size-3 task
//    arrives, because 2 + 3 would exceed max_batch_size (4);
//  * that size-3 task then forms its own batch, which closes only when its own
//    10us timeout expires.
// Each Env::Default()->SleepForMicroseconds(10 * 1000) is a real-time pause.
// It gives scheduler threads a chance to run before asserting that a batch has
// NOT been processed yet.
// NOTE(review): the lines that presumably wire `env` into `options` (151-152)
// are not visible in this excerpt — confirm.
// NOTE(review): presumably no_tasks_wait_time_micros = 0 keeps idle batch
// threads from adding their own sleeps on the fake clock — confirm.
128 TEST(StreamingBatchSchedulerTest, Timeout) {
130 test_util::FakeClockEnv env(Env::Default());
132 Notification first_batch_processed, second_batch_processed,
133 third_batch_processed;
// The callback tells the batches apart by total size (1, 2 or 3) and signals
// the matching notification once the batch is closed.
134 auto callback = [&first_batch_processed, &second_batch_processed,
135 &third_batch_processed](
136 std::unique_ptr<Batch<FakeTask>> batch) {
137 batch->WaitUntilClosed();
138 if (batch->size() == 1) {
139 first_batch_processed.Notify();
140 }
else if (batch->size() == 2) {
141 second_batch_processed.Notify();
142 }
else if (batch->size() == 3) {
143 third_batch_processed.Notify();
147 StreamingBatchScheduler<FakeTask>::Options options;
148 options.max_batch_size = 4;
149 options.batch_timeout_micros = 10;
150 options.num_batch_threads = 10;
153 options.no_tasks_wait_time_micros = 0;
154 std::unique_ptr<StreamingBatchScheduler<FakeTask>> scheduler;
156 StreamingBatchScheduler<FakeTask>::Create(options, callback, &scheduler));
// Batch 1 (size 1): scheduled at fake time 0, so its deadline is 10.
// Advance the fake clock 0 -> 9 (not yet timed out), then 9 -> 10.
160 TF_ASSERT_OK(ScheduleTask(1, scheduler.get()));
161 env.BlockUntilSleepingThread(10);
162 env.AdvanceByMicroseconds(9);
163 Env::Default()->SleepForMicroseconds(10 * 1000 );
164 EXPECT_FALSE(first_batch_processed.HasBeenNotified());
165 env.AdvanceByMicroseconds(1);
166 first_batch_processed.WaitForNotification();
// Batch 2 (size 2): scheduled at fake time 10, so its deadline is 20.
// Advance to 19, which is not yet timed out.
171 TF_ASSERT_OK(ScheduleTask(2, scheduler.get()));
172 env.BlockUntilSleepingThread(20);
173 env.AdvanceByMicroseconds(9);
174 Env::Default()->SleepForMicroseconds(10 * 1000 );
175 EXPECT_FALSE(second_batch_processed.HasBeenNotified());
// 2 + 3 exceeds max_batch_size (4). Scheduling this task closes batch 2 with
// no clock advance, and the task starts batch 3 at fake time 19 (deadline 29).
176 TF_ASSERT_OK(ScheduleTask(3, scheduler.get()));
177 second_batch_processed.WaitForNotification();
// Batch 3 (size 3): advance 19 -> 28 (not yet timed out), then 28 -> 29.
181 env.AdvanceByMicroseconds(9);
182 Env::Default()->SleepForMicroseconds(10 * 1000 );
183 EXPECT_FALSE(third_batch_processed.HasBeenNotified());
184 env.BlockUntilSleepingThread(29);
185 env.AdvanceByMicroseconds(1);
186 third_batch_processed.WaitForNotification();
// Real-clock counterpart of the Timeout test. Each task (size 1, then size 2)
// is well below max_batch_size (10), so its batch is expected to be closed by
// the 10 ms batch timeout. The test only waits for each batch to be processed;
// it does not measure how precisely the timeout fires.
189 TEST(StreamingBatchSchedulerTest, RealClockTimeout) {
190 Notification first_batch_processed, second_batch_processed;
// Signals the notification matching the closed batch's total size (1 or 2).
191 auto callback = [&first_batch_processed, &second_batch_processed](
192 std::unique_ptr<Batch<FakeTask>> batch) {
193 batch->WaitUntilClosed();
194 if (batch->size() == 1) {
195 first_batch_processed.Notify();
196 }
else if (batch->size() == 2) {
197 second_batch_processed.Notify();
201 StreamingBatchScheduler<FakeTask>::Options options;
202 options.max_batch_size = 10;
203 options.batch_timeout_micros = 10 * 1000;
204 options.num_batch_threads = 10;
205 std::unique_ptr<StreamingBatchScheduler<FakeTask>> scheduler;
207 StreamingBatchScheduler<FakeTask>::Create(options, callback, &scheduler));
// Each batch is scheduled only after the previous one has been processed.
211 TF_ASSERT_OK(ScheduleTask(1, scheduler.get()));
212 first_batch_processed.WaitForNotification();
215 TF_ASSERT_OK(ScheduleTask(2, scheduler.get()));
216 second_batch_processed.WaitForNotification();
// Schedules a single task of size 3, below max_batch_size (10), and expects the
// callback to have run by the time callback_called is checked. Per the test
// name, the under-full batch should be processed when the scheduler is deleted,
// rather than by waiting for the 100 ms timeout.
// NOTE(review): the scheduler's destruction (presumably line 238) is not
// visible in this excerpt — confirm.
219 TEST(StreamingBatchSchedulerTest, FinalUnderfullBatchProcessedUponDeletion) {
220 bool callback_called =
false;
221 auto callback = [&callback_called](std::unique_ptr<Batch<FakeTask>> batch) {
222 batch->WaitUntilClosed();
223 callback_called =
true;
227 StreamingBatchScheduler<FakeTask>::Options options;
228 options.max_batch_size = 10;
229 options.batch_timeout_micros = 100 * 1000;
230 options.num_batch_threads = 1;
231 std::unique_ptr<StreamingBatchScheduler<FakeTask>> scheduler;
232 TF_ASSERT_OK(StreamingBatchScheduler<FakeTask>::Create(options, callback,
237 TF_ASSERT_OK(ScheduleTask(3, scheduler.get()));
239 EXPECT_TRUE(callback_called);
// Verifies the streaming hand-off. The callback receives a batch while it is
// still open (IsClosed() is false) and holds at most one task. This means the
// batch is handed over when it is first created, not after it fills or times
// out. The callback then blocks in WaitUntilClosed() until the scheduler
// closes the batch.
242 TEST(StreamingBatchSchedulerTest, BatchHandedToCallbackWhenFirstCreated) {
243 Notification stop_scheduler;
244 auto callback = [&stop_scheduler](std::unique_ptr<Batch<FakeTask>> batch) {
245 EXPECT_LE(batch->num_tasks(), 1);
246 EXPECT_FALSE(batch->IsClosed());
// Let the test body proceed only after the open-batch checks above have run.
247 stop_scheduler.Notify();
248 batch->WaitUntilClosed();
251 StreamingBatchScheduler<FakeTask>::Options options;
252 options.max_batch_size = 100;
253 options.batch_timeout_micros = 100 * 1000;
254 options.num_batch_threads = 1;
255 std::unique_ptr<StreamingBatchScheduler<FakeTask>> scheduler;
257 StreamingBatchScheduler<FakeTask>::Create(options, callback, &scheduler));
260 TF_ASSERT_OK(ScheduleTask(1, scheduler.get()));
262 stop_scheduler.WaitForNotification();
265 TEST(StreamingBatchSchedulerTest, ConstMethods) {
266 for (
const int num_threads : {1, 2, 3}) {
267 Notification proceed;
268 auto callback = [&proceed](std::unique_ptr<Batch<FakeTask>> batch) {
269 batch->WaitUntilClosed();
270 proceed.WaitForNotification();
273 StreamingBatchScheduler<FakeTask>::Options options;
274 options.max_batch_size = 2;
275 options.batch_timeout_micros = 1 * 1000 * 1000;
276 options.num_batch_threads = num_threads;
277 std::unique_ptr<StreamingBatchScheduler<FakeTask>> scheduler;
278 TF_ASSERT_OK(StreamingBatchScheduler<FakeTask>::Create(options, callback,
281 EXPECT_EQ(2, scheduler->max_task_size());
286 for (
int i = 0; i < num_threads; ++i) {
287 EXPECT_EQ(0, scheduler->NumEnqueuedTasks());
288 EXPECT_EQ((num_threads - i) * 2, scheduler->SchedulingCapacity());
289 TF_ASSERT_OK(ScheduleTask(1, scheduler.get()));
290 EXPECT_EQ(0, scheduler->NumEnqueuedTasks());
291 EXPECT_EQ((num_threads - i) * 2 - 1, scheduler->SchedulingCapacity());
292 TF_ASSERT_OK(ScheduleTask(1, scheduler.get()));
294 EXPECT_EQ(0, scheduler->NumEnqueuedTasks());
295 EXPECT_EQ(0, scheduler->SchedulingCapacity());
299 Status status = ScheduleTask(1, scheduler.get());
300 EXPECT_FALSE(status.ok());
301 EXPECT_EQ(error::UNAVAILABLE, status.code());
302 EXPECT_EQ(0, scheduler->NumEnqueuedTasks());
303 EXPECT_EQ(0, scheduler->SchedulingCapacity());
308 Env::Default()->SleepForMicroseconds(100 * 1000 );
312 EXPECT_EQ(num_threads * 2, scheduler->SchedulingCapacity());
313 TF_EXPECT_OK(ScheduleTask(1, scheduler.get()));