16 #include "tensorflow_serving/batching/test_util/puppet_batch_scheduler.h"
20 #include <gtest/gtest.h>
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/macros.h"
25 namespace tensorflow {
// Minimal BatchTask subclass used as the task type throughout this test file.
// NOTE(review): the leading numbers on each line and the absence of a closing
// brace for this class look like extraction residue — confirm against the
// original file before relying on this view.
30 class FakeTask :
public BatchTask {
33 ~FakeTask()
override =
default;
// Every task reports a size of 1, so summing sizes is the same as counting
// tasks.
35 size_t size()
const override {
return 1; }
// Applies the TF_DISALLOW_COPY_AND_ASSIGN macro to FakeTask; going by its
// name this forbids copying, but the macro definition is not visible here.
38 TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
// Test helper: heap-allocates a fresh FakeTask, passes it to
// scheduler->Schedule(), and wraps the call in TF_ASSERT_OK so that a failed
// Schedule() call fails the test.
43 void ScheduleTask(BatchScheduler<FakeTask>* scheduler) {
44 std::unique_ptr<FakeTask> task(
new FakeTask);
// Schedule() is given the address of the unique_ptr, not the task itself —
// presumably so the scheduler can take ownership; confirm against the
// BatchScheduler interface.
45 TF_ASSERT_OK(scheduler->Schedule(&task));
// Checks that PuppetBatchScheduler only holds tasks when they are scheduled,
// and runs the callback on them when ProcessTasks()/ProcessAllTasks() is
// called.
48 TEST(PuppetBatchSchedulerTest, Basic) {
// Running total of batch sizes seen by the callback; FakeTask::size() is 1, so
// this equals the number of tasks processed.
49 int num_tasks_processed = 0;
// Batch-processing lambda: asserts each batch it receives is closed, then adds
// the batch's size to the counter captured by reference.
// NOTE(review): `callback` is used below, but the declaration that binds this
// lambda to that name is not visible in this view — confirm against the
// original file.
51 [&num_tasks_processed](std::unique_ptr<Batch<FakeTask>> batch) {
52 ASSERT_TRUE(batch->IsClosed());
53 num_tasks_processed += batch->size();
55 PuppetBatchScheduler<FakeTask> scheduler(callback);
// Phase 1: schedule three tasks. Before each one, nothing has been processed
// and the enqueued count equals the number scheduled so far.
57 for (
int i = 0; i < 3; ++i) {
58 EXPECT_EQ(0, num_tasks_processed);
59 EXPECT_EQ(i, scheduler.NumEnqueuedTasks());
60 ScheduleTask(&scheduler);
62 EXPECT_EQ(3, scheduler.NumEnqueuedTasks());
// Phase 2: process two of the three tasks; one must remain enqueued.
64 scheduler.ProcessTasks(2);
65 EXPECT_EQ(2, num_tasks_processed);
66 EXPECT_EQ(1, scheduler.NumEnqueuedTasks());
// Phase 3: scheduling one more task grows the queue without processing
// anything.
68 ScheduleTask(&scheduler);
69 EXPECT_EQ(2, num_tasks_processed);
70 EXPECT_EQ(2, scheduler.NumEnqueuedTasks());
// Phase 4: draining the queue processes the remaining two tasks (four in
// total) and leaves nothing enqueued.
72 scheduler.ProcessAllTasks();
73 EXPECT_EQ(4, num_tasks_processed);
74 EXPECT_EQ(0, scheduler.NumEnqueuedTasks());