16 #include "tensorflow_serving/core/source_router.h"
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow_serving/core/servable_data.h"
30 #include "tensorflow_serving/core/storage_path.h"
31 #include "tensorflow_serving/core/target.h"
32 #include "tensorflow_serving/core/test_util/mock_storage_path_target.h"
35 using ::testing::ElementsAre;
37 using ::testing::IsEmpty;
38 using ::testing::StrictMock;
40 namespace tensorflow {
// Test-only SourceRouter over StoragePath with a configurable number of
// output ports (default 2). Route() dispatches on the servable name; the
// names it recognizes are "zero", "one" and "no_route".
// NOTE(review): the return statement of each Route() branch is not present in
// this excerpt. Presumably "zero" -> port 0, "one" -> port 1 and "no_route" ->
// a "discard" value; confirm against the full file.
16 class TestSourceRouter final :
public SourceRouter<StoragePath> {
// `num_ports` becomes the fixed value reported by num_output_ports().
46 TestSourceRouter(
int num_ports = 2) : num_output_ports_(num_ports) {}
// Calls Detach() on destruction (presumably inherited from SourceRouter and
// needed before members are torn down; TODO confirm).
47 ~TestSourceRouter()
override { Detach(); }
// Port count; set once in the constructor and never modified.
50 const int num_output_ports_;
51 int num_output_ports()
const override {
return num_output_ports_; }
// Chooses an output port for `servable_name`. `versions` is not inspected
// in the visible code.
53 int Route(
const StringPiece servable_name,
54 const std::vector<ServableData<StoragePath>>& versions)
override {
55 if (servable_name ==
"zero") {
57 }
else if (servable_name ==
"one") {
59 }
else if (servable_name ==
"no_route") {
// Presumably reached for any unrecognized name (the enclosing `else` is not
// visible here): such a name indicates a broken test, so abort loudly.
62 LOG(FATAL) <<
"Unexpected test data";
// Non-copyable.
67 TF_DISALLOW_COPY_AND_ASSIGN(TestSourceRouter);
// Checks that each recognized servable name reaches the target attached to
// its port: the "zero" request is expected on targets[0] and the "one"
// request on targets[1]. The targets are StrictMocks, so any call not
// covered by an EXPECT_CALL fails the test.
70 TEST(SourceRouterTest, Basic) {
71 TestSourceRouter router;
72 std::vector<Source<StoragePath>*> output_ports = router.GetOutputPorts();
// A default-constructed TestSourceRouter exposes two ports.
73 ASSERT_EQ(2, output_ports.size());
74 std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
// Attach one strict mock target per port; `targets` owns them for the rest
// of the test.
// NOTE(review): `int i` is compared with the unsigned size(); this may warn
// under -Wsign-compare.
75 for (
int i = 0; i < output_ports.size(); ++i) {
76 std::unique_ptr<test_util::MockStoragePathTarget> target(
77 new StrictMock<test_util::MockStoragePathTarget>);
78 ConnectSourceToTarget(output_ports[i], target.get());
79 targets.push_back(std::move(target));
// targets[0] must receive exactly the single ServableData sent below.
// NOTE(review): the middle of this EXPECT_CALL is missing from this excerpt.
82 EXPECT_CALL(*targets[0],
85 ElementsAre(ServableData<StoragePath>({
"zero", 0},
"mrop"))));
86 router.SetAspiredVersions(
"zero",
87 {ServableData<StoragePath>({
"zero", 0},
"mrop")});
// targets[1] must receive the "one" request with its single ServableData.
89 EXPECT_CALL(*targets[1], SetAspiredVersions(
90 Eq(
"one"), ElementsAre(ServableData<StoragePath>(
91 {
"one", 1},
"floo"))));
92 router.SetAspiredVersions(
"one",
93 {ServableData<StoragePath>({
"one", 1},
"floo")});
// Checks that the constructor's `num_ports` argument controls how many ports
// GetOutputPorts() returns. It then checks that a "zero" request is delivered
// to the single connected target.
96 TEST(SourceRouterTest, NumPorts) {
97 TestSourceRouter router(1);
98 std::vector<Source<StoragePath>*> output_ports = router.GetOutputPorts();
99 ASSERT_EQ(1, output_ports.size());
// Strict mock: only the call expected below is allowed.
100 std::unique_ptr<test_util::MockStoragePathTarget> target(
101 new StrictMock<test_util::MockStoragePathTarget>);
102 ConnectSourceToTarget(output_ports[0], target.get());
// The lone target must receive exactly the one ServableData sent for "zero".
104 EXPECT_CALL(*target, SetAspiredVersions(Eq(
"zero"),
105 ElementsAre(ServableData<StoragePath>(
106 {
"zero", 0},
"mrop"))));
107 router.SetAspiredVersions(
"zero",
108 {ServableData<StoragePath>({
"zero", 0},
"mrop")});
// Scenario 1: the output ports are fetched and connected on a separate thread
// ("ConnectTargets") after a delay. The test name suggests that the
// SetAspiredVersions() calls at the end must block until all ports have
// targets.
// NOTE(review): the thread lambda's header, the declaration of `done` and
// the matching Notify() call are not present in this excerpt. Statements
// below up to WaitForNotification() presumably run inside that lambda;
// confirm against the full file.
111 TEST(SourceRouterTest, SetAspiredVersionsBlocksUntilAllTargetsConnected_1) {
116 TestSourceRouter router;
120 std::unique_ptr<Thread> connect_targets(Env::Default()->StartThread(
121 {},
"ConnectTargets",
// Sleep 1 second (1,000,000 us) so the main thread's SetAspiredVersions()
// very likely arrives before any port is connected.
126 Env::Default()->SleepForMicroseconds(1 * 1000 * 1000 );
128 std::vector<Source<StoragePath>*> output_ports =
129 router.GetOutputPorts();
130 ASSERT_EQ(2, output_ports.size());
131 std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
// NOTE(review): `int i` is compared with the unsigned size(); this may warn
// under -Wsign-compare.
132 for (
int i = 0; i < output_ports.size(); ++i) {
133 std::unique_ptr<test_util::MockStoragePathTarget> target(
134 new StrictMock<test_util::MockStoragePathTarget>);
// Each target expects exactly one call with an empty version list, for any
// servable name. The expectation is set before connecting, so it exists
// before any call can reach the target.
135 EXPECT_CALL(*target, SetAspiredVersions(_, IsEmpty()));
136 ConnectSourceToTarget(output_ports[i], target.get());
137 targets.push_back(std::move(target));
// Presumably keeps `targets` alive until both SetAspiredVersions() calls
// below have been delivered (the Notify() is not visible here).
139 done.WaitForNotification();
// One request per servable name that TestSourceRouter routes.
142 router.SetAspiredVersions(
"zero", {});
143 router.SetAspiredVersions(
"one", {});
// Scenario 2: the ports are fetched on the main thread and port 0 is
// connected immediately. Port 1 is connected from thread "ConnectTarget1"
// after a 1 second delay. The test name suggests that the final
// SetAspiredVersions("one", ...) must wait for that second connection.
148 TEST(SourceRouterTest, SetAspiredVersionsBlocksUntilAllTargetsConnected_2) {
153 TestSourceRouter router;
154 std::vector<Source<StoragePath>*> output_ports = router.GetOutputPorts();
155 ASSERT_EQ(2, output_ports.size());
// Create both strict mock targets up front; only targets[0] is connected
// synchronously.
156 std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
157 for (
int i = 0; i < output_ports.size(); ++i) {
158 std::unique_ptr<test_util::MockStoragePathTarget> target(
159 new StrictMock<test_util::MockStoragePathTarget>);
160 targets.push_back(std::move(target));
164 ConnectSourceToTarget(output_ports[0], targets[0].get());
// The lambda captures `output_ports` and `targets` by reference. Because
// `connect_target_1` is declared after both, it is destroyed first.
// NOTE(review): this relies on Thread's destructor waiting for the thread to
// finish; confirm in tensorflow/core/platform/env.h.
167 std::unique_ptr<Thread> connect_target_1(Env::Default()->StartThread(
168 {},
"ConnectTarget1",
169 [&output_ports, &targets] {
// Sleep 1 second (1,000,000 us) before connecting port 1, so the main
// thread's call very likely arrives first.
173 Env::Default()->SleepForMicroseconds(1 * 1000 * 1000 );
174 ConnectSourceToTarget(output_ports[1], targets[1].get());
// targets[1] must receive exactly one "one" request with no versions.
177 EXPECT_CALL(*targets[1], SetAspiredVersions(Eq(
"one"), IsEmpty()));
178 router.SetAspiredVersions(
"one", {});
// Exercises the "no_route" path. Every port is connected to a StrictMock
// target, and no EXPECT_CALL is set in the visible code, so any
// SetAspiredVersions() forwarded to a target would fail the test.
// NOTE(review): the value Route() returns for "no_route" is not visible in
// this excerpt; presumably it is a "discard" sentinel.
181 TEST(SourceRouterTest, DiscardRequest) {
184 TestSourceRouter router;
// Unlike the other tests, the port count is not asserted here.
185 std::vector<Source<StoragePath>*> output_ports = router.GetOutputPorts();
186 std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
// `targets` takes ownership of each mock via emplace_back of the raw pointer.
187 for (
int i = 0; i < output_ports.size(); ++i) {
188 targets.emplace_back(
new StrictMock<test_util::MockStoragePathTarget>);
189 ConnectSourceToTarget(output_ports[i], targets.back().get());
192 router.SetAspiredVersions(
"no_route", {});