TensorFlow Serving C++ API Documentation
source_router_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/source_router.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow_serving/core/servable_data.h"
30 #include "tensorflow_serving/core/storage_path.h"
31 #include "tensorflow_serving/core/target.h"
32 #include "tensorflow_serving/core/test_util/mock_storage_path_target.h"
33 
34 using ::testing::_;
35 using ::testing::ElementsAre;
36 using ::testing::Eq;
37 using ::testing::IsEmpty;
38 using ::testing::StrictMock;
39 
40 namespace tensorflow {
41 namespace serving {
42 namespace {
43 
44 class TestSourceRouter final : public SourceRouter<StoragePath> {
45  public:
46  TestSourceRouter(int num_ports = 2) : num_output_ports_(num_ports) {}
47  ~TestSourceRouter() override { Detach(); }
48 
49  protected:
50  const int num_output_ports_;
51  int num_output_ports() const override { return num_output_ports_; }
52 
53  int Route(const StringPiece servable_name,
54  const std::vector<ServableData<StoragePath>>& versions) override {
55  if (servable_name == "zero") {
56  return 0;
57  } else if (servable_name == "one") {
58  return 1;
59  } else if (servable_name == "no_route") {
60  return kNoRoute;
61  } else {
62  LOG(FATAL) << "Unexpected test data";
63  }
64  }
65 
66  private:
67  TF_DISALLOW_COPY_AND_ASSIGN(TestSourceRouter);
68 };
69 
70 TEST(SourceRouterTest, Basic) {
71  TestSourceRouter router;
72  std::vector<Source<StoragePath>*> output_ports = router.GetOutputPorts();
73  ASSERT_EQ(2, output_ports.size());
74  std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
75  for (int i = 0; i < output_ports.size(); ++i) {
76  std::unique_ptr<test_util::MockStoragePathTarget> target(
77  new StrictMock<test_util::MockStoragePathTarget>);
78  ConnectSourceToTarget(output_ports[i], target.get());
79  targets.push_back(std::move(target));
80  }
81 
82  EXPECT_CALL(*targets[0],
83  SetAspiredVersions(
84  Eq("zero"),
85  ElementsAre(ServableData<StoragePath>({"zero", 0}, "mrop"))));
86  router.SetAspiredVersions("zero",
87  {ServableData<StoragePath>({"zero", 0}, "mrop")});
88 
89  EXPECT_CALL(*targets[1], SetAspiredVersions(
90  Eq("one"), ElementsAre(ServableData<StoragePath>(
91  {"one", 1}, "floo"))));
92  router.SetAspiredVersions("one",
93  {ServableData<StoragePath>({"one", 1}, "floo")});
94 }
95 
96 TEST(SourceRouterTest, NumPorts) {
97  TestSourceRouter router(1);
98  std::vector<Source<StoragePath>*> output_ports = router.GetOutputPorts();
99  ASSERT_EQ(1, output_ports.size());
100  std::unique_ptr<test_util::MockStoragePathTarget> target(
101  new StrictMock<test_util::MockStoragePathTarget>);
102  ConnectSourceToTarget(output_ports[0], target.get());
103 
104  EXPECT_CALL(*target, SetAspiredVersions(Eq("zero"),
105  ElementsAre(ServableData<StoragePath>(
106  {"zero", 0}, "mrop"))));
107  router.SetAspiredVersions("zero",
108  {ServableData<StoragePath>({"zero", 0}, "mrop")});
109 }
110 
111 TEST(SourceRouterTest, SetAspiredVersionsBlocksUntilAllTargetsConnected_1) {
112  // Scenario 1: When SetAspiredVersions() is invoked, GetOutputPorts() has not
113  // yet been called. The SetAspiredVersions() call should block until the ports
114  // have been emitted and all of them have been connected to targets.
115 
116  TestSourceRouter router;
117  Notification done;
118 
119  // Connect the output ports to targets asynchronously, after a long delay.
120  std::unique_ptr<Thread> connect_targets(Env::Default()->StartThread(
121  {}, "ConnectTargets",
122  [&router, &done] {
123  // Sleep for a long time before calling GetOutputPorts(), to make it
124  // very likely that SetAspiredVersions() gets called first and has to
125  // block.
126  Env::Default()->SleepForMicroseconds(1 * 1000 * 1000 /* 1 second */);
127 
128  std::vector<Source<StoragePath>*> output_ports =
129  router.GetOutputPorts();
130  ASSERT_EQ(2, output_ports.size());
131  std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
132  for (int i = 0; i < output_ports.size(); ++i) {
133  std::unique_ptr<test_util::MockStoragePathTarget> target(
134  new StrictMock<test_util::MockStoragePathTarget>);
135  EXPECT_CALL(*target, SetAspiredVersions(_, IsEmpty()));
136  ConnectSourceToTarget(output_ports[i], target.get());
137  targets.push_back(std::move(target));
138  }
139  done.WaitForNotification();
140  }));
141 
142  router.SetAspiredVersions("zero", {});
143  router.SetAspiredVersions("one", {});
144 
145  done.Notify();
146 }
147 
148 TEST(SourceRouterTest, SetAspiredVersionsBlocksUntilAllTargetsConnected_2) {
149  // Scenario 2: When SetAspiredVersions() is invoked, GetOutputPorts() has been
150  // called but only one of the two ports has been connected to a target. The
151  // SetAspiredVersions() call should block until the other port is connected.
152 
153  TestSourceRouter router;
154  std::vector<Source<StoragePath>*> output_ports = router.GetOutputPorts();
155  ASSERT_EQ(2, output_ports.size());
156  std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
157  for (int i = 0; i < output_ports.size(); ++i) {
158  std::unique_ptr<test_util::MockStoragePathTarget> target(
159  new StrictMock<test_util::MockStoragePathTarget>);
160  targets.push_back(std::move(target));
161  }
162 
163  // Connect target 0 now.
164  ConnectSourceToTarget(output_ports[0], targets[0].get());
165 
166  // Connect target 1 asynchronously after a long delay.
167  std::unique_ptr<Thread> connect_target_1(Env::Default()->StartThread(
168  {}, "ConnectTarget1",
169  [&output_ports, &targets] {
170  // Sleep for a long time before connecting target 1, to make it very
171  // likely that SetAspiredVersions() gets called first and has to
172  // block.
173  Env::Default()->SleepForMicroseconds(1 * 1000 * 1000 /* 1 second */);
174  ConnectSourceToTarget(output_ports[1], targets[1].get());
175  }));
176 
177  EXPECT_CALL(*targets[1], SetAspiredVersions(Eq("one"), IsEmpty()));
178  router.SetAspiredVersions("one", {});
179 }
180 
181 TEST(SourceRouterTest, DiscardRequest) {
182  // Testing return kNoRoute to discard a request
183 
184  TestSourceRouter router;
185  std::vector<Source<StoragePath>*> output_ports = router.GetOutputPorts();
186  std::vector<std::unique_ptr<test_util::MockStoragePathTarget>> targets;
187  for (int i = 0; i < output_ports.size(); ++i) {
188  targets.emplace_back(new StrictMock<test_util::MockStoragePathTarget>);
189  ConnectSourceToTarget(output_ports[i], targets.back().get());
190  }
191 
192  router.SetAspiredVersions("no_route", {});
193 
194  // Expect no `SetAspiredVersions` call on `output_ports[0]`.
195 }
196 
197 } // namespace
198 } // namespace serving
199 } // namespace tensorflow