16 #include "tensorflow_serving/core/dynamic_source_router.h"
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/core/stringpiece.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow_serving/core/servable_data.h"
30 #include "tensorflow_serving/core/source.h"
31 #include "tensorflow_serving/core/storage_path.h"
32 #include "tensorflow_serving/core/target.h"
33 #include "tensorflow_serving/core/test_util/mock_storage_path_target.h"
35 using ::testing::ElementsAre;
37 using ::testing::StrictMock;
39 namespace tensorflow {
43 TEST(DynamicSourceRouterTest, InvalidRouteMap) {
44 std::unique_ptr<DynamicSourceRouter<StoragePath>> router;
47 DynamicSourceRouter<string>::Create(2, {{
"foo", -1}}, &router).ok());
50 DynamicSourceRouter<string>::Create(2, {{
"foo", 2}}, &router).ok());
53 DynamicSourceRouter<string>::Create(2, {{
"foo", 1}}, &router).ok());
56 TEST(DynamicSourceRouterTest, ReconfigureToInvalidRouteMap) {
57 std::unique_ptr<DynamicSourceRouter<StoragePath>> router;
58 TF_ASSERT_OK(DynamicSourceRouter<string>::Create(2, {{
"foo", 0}}, &router));
60 EXPECT_FALSE(router->UpdateRoutes({{
"foo", -1}}).ok());
62 EXPECT_FALSE(router->UpdateRoutes({{
"foo", 2}}).ok());
64 EXPECT_FALSE(router->UpdateRoutes({{
"foo", 1}}).ok());
67 TEST(DynamicSourceRouterTest, Basic) {
68 std::unique_ptr<DynamicSourceRouter<StoragePath>> router;
69 DynamicSourceRouter<StoragePath>::Routes routes = {{
"foo", 0}, {
"bar", 1}};
70 TF_ASSERT_OK(DynamicSourceRouter<string>::Create(4, routes, &router));
71 EXPECT_EQ(routes, router->GetRoutes());
72 std::vector<Source<StoragePath>*> output_ports = router->GetOutputPorts();
73 ASSERT_EQ(4, output_ports.size());
74 std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
75 for (
int i = 0; i < output_ports.size(); ++i) {
76 std::unique_ptr<test_util::MockStoragePathTarget> target(
77 new StrictMock<test_util::MockStoragePathTarget>);
78 ConnectSourceToTarget(output_ports[i], target.get());
79 targets.push_back(std::move(target));
83 EXPECT_CALL(*targets[0], SetAspiredVersions(
84 Eq(
"foo"), ElementsAre(ServableData<StoragePath>(
85 {
"foo", 7},
"data"))));
86 router->SetAspiredVersions(
"foo",
87 {ServableData<StoragePath>({
"foo", 7},
"data")});
90 EXPECT_CALL(*targets[1], SetAspiredVersions(
91 Eq(
"bar"), ElementsAre(ServableData<StoragePath>(
92 {
"bar", 7},
"data"))));
93 router->SetAspiredVersions(
"bar",
94 {ServableData<StoragePath>({
"bar", 7},
"data")});
97 EXPECT_CALL(*targets[3],
98 SetAspiredVersions(Eq(
"not_foo_or_bar"),
99 ElementsAre(ServableData<StoragePath>(
100 {
"not_foo_or_bar", 7},
"data"))));
101 router->SetAspiredVersions(
103 {ServableData<StoragePath>({
"not_foo_or_bar", 7},
"data")});
106 TEST(DynamicSourceRouterTest, Reconfigure) {
107 std::unique_ptr<DynamicSourceRouter<StoragePath>> router;
108 TF_ASSERT_OK(DynamicSourceRouter<string>::Create(2, {{
"foo", 0}}, &router));
109 std::vector<Source<StoragePath>*> output_ports = router->GetOutputPorts();
110 ASSERT_EQ(2, output_ports.size());
111 std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
112 for (
int i = 0; i < output_ports.size(); ++i) {
113 std::unique_ptr<test_util::MockStoragePathTarget> target(
114 new StrictMock<test_util::MockStoragePathTarget>);
115 ConnectSourceToTarget(output_ports[i], target.get());
116 targets.push_back(std::move(target));
120 EXPECT_CALL(*targets[0], SetAspiredVersions(
121 Eq(
"foo"), ElementsAre(ServableData<StoragePath>(
122 {
"foo", 7},
"data"))));
123 router->SetAspiredVersions(
"foo",
124 {ServableData<StoragePath>({
"foo", 7},
"data")});
125 EXPECT_CALL(*targets[1], SetAspiredVersions(
126 Eq(
"bar"), ElementsAre(ServableData<StoragePath>(
127 {
"bar", 7},
"data"))));
128 router->SetAspiredVersions(
"bar",
129 {ServableData<StoragePath>({
"bar", 7},
"data")});
131 TF_ASSERT_OK(router->UpdateRoutes({{
"bar", 0}}));
134 EXPECT_CALL(*targets[1], SetAspiredVersions(
135 Eq(
"foo"), ElementsAre(ServableData<StoragePath>(
136 {
"foo", 7},
"data"))));
137 router->SetAspiredVersions(
"foo",
138 {ServableData<StoragePath>({
"foo", 7},
"data")});
139 EXPECT_CALL(*targets[0], SetAspiredVersions(
140 Eq(
"bar"), ElementsAre(ServableData<StoragePath>(
141 {
"bar", 7},
"data"))));
142 router->SetAspiredVersions(
"bar",
143 {ServableData<StoragePath>({
"bar", 7},
"data")});