TensorFlow Serving C++ API Documentation
manager_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/manager.h"
17 
18 #include <map>
19 #include <memory>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow_serving/core/test_util/servable_handle_test_util.h"
25 #include "tensorflow_serving/util/any_ptr.h"
26 
27 namespace tensorflow {
28 namespace serving {
29 namespace {
30 
// Minimal servable payload used throughout these tests. `member` starts at 7 so
// a test can confirm it reached the real object through a handle.
struct TestServable {
  int member{7};
};
34 
35 class TestHandle : public UntypedServableHandle {
36  public:
37  AnyPtr servable() override { return &servable_; }
38 
39  const ServableId& id() const override { return id_; }
40 
41  private:
42  const ServableId id_ = {"servable", 7};
43  TestServable servable_;
44 };
45 
46 // A manager that a returns a TestHandle.
47 class TestManager : public Manager {
48  public:
49  std::vector<ServableId> ListAvailableServableIds() const override {
50  LOG(FATAL) << "Not expected to be called.";
51  }
52 
53  private:
54  Status GetUntypedServableHandle(
55  const ServableRequest& request,
56  std::unique_ptr<UntypedServableHandle>* result) override {
57  result->reset(new TestHandle);
58  return OkStatus();
59  }
60 
61  std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
62  GetAvailableUntypedServableHandles() const override {
63  std::map<ServableId, std::unique_ptr<UntypedServableHandle>> handles;
64  handles.emplace(ServableId{"Foo", 2},
65  std::unique_ptr<UntypedServableHandle>(new TestHandle));
66  return handles;
67  }
68 };
69 
70 TEST(ManagerTest, NoErrors) {
71  TestManager manager;
72  ServableHandle<TestServable> handle;
73  EXPECT_TRUE(manager.GetServableHandle({"Foo", 2}, &handle).ok());
74  EXPECT_NE(nullptr, handle.get());
75 }
76 
77 TEST(ManagerTest, TypeError) {
78  TestManager manager;
79  ServableHandle<int> handle;
80  EXPECT_FALSE(manager.GetServableHandle({"Foo", 2}, &handle).ok());
81  EXPECT_EQ(nullptr, handle.get());
82 }
83 
84 TEST(ManagerTest, GetAvailableServableHandles) {
85  TestManager manager;
86  const std::map<ServableId, ServableHandle<TestServable>> handles =
87  manager.GetAvailableServableHandles<TestServable>();
88  ASSERT_EQ(1, handles.size());
89  for (const auto& handle : handles) {
90  EXPECT_EQ((ServableId{"Foo", 2}), handle.first);
91  EXPECT_EQ(7, handle.second->member);
92  }
93 }
94 
95 TEST(ManagerTest, GetAvailableServableHandlesWrongType) {
96  TestManager manager;
97  const std::map<ServableId, ServableHandle<int>> handles =
98  manager.GetAvailableServableHandles<int>();
99  EXPECT_EQ(0, handles.size());
100 }
101 
102 // A manager that returns OK even though the result is null. This behavior
103 // violates the interface of Manager, but it is used to test that this violation
104 // is handled gracefully rather than a crash or memory corruption.
105 class ReturnNullManager : public TestManager {
106  private:
107  Status GetUntypedServableHandle(
108  const ServableRequest& request,
109  std::unique_ptr<UntypedServableHandle>* result) override {
110  *result = nullptr;
111  return OkStatus();
112  }
113 };
114 
115 TEST(ManagerTest, NullHandleReturnsError) {
116  ReturnNullManager manager;
117  ServableHandle<TestServable> handle;
118  EXPECT_FALSE(manager.GetServableHandle({"Foo", 2}, &handle).ok());
119  EXPECT_EQ(nullptr, handle.get());
120 }
121 
122 // A manager that returns an error even though the result is non-null.
123 class ReturnErrorManager : public TestManager {
124  private:
125  Status GetUntypedServableHandle(
126  const ServableRequest& request,
127  std::unique_ptr<UntypedServableHandle>* result) override {
128  result->reset(new TestHandle);
129  return errors::Internal("Something bad happened.");
130  }
131 };
132 
133 TEST(ManagerTest, ErrorReturnsNullHandle) {
134  ReturnErrorManager manager;
135  ServableHandle<TestServable> handle;
136  EXPECT_FALSE(manager.GetServableHandle({"Foo", 2}, &handle).ok());
137  EXPECT_EQ(nullptr, handle.get());
138 }
139 
140 TEST(ServableHandleTest, PointerOps) {
141  TestServable servables[2];
142  ServableHandle<TestServable> handles[2];
143 
144  const ServableId id = {"servable", 7};
145  handles[0] = test_util::WrapAsHandle(id, &servables[0]);
146  handles[1] = test_util::WrapAsHandle(id, &servables[1]);
147 
148  // Equality.
149  EXPECT_EQ(handles[0], handles[0]);
150 
151  // Inequality.
152  EXPECT_NE(handles[0], handles[1]);
153 
154  // Bool conversion.
155  EXPECT_TRUE(handles[0]);
156 
157  // Dereference and get.
158  EXPECT_EQ(&servables[0], handles[0].get());
159  EXPECT_EQ(&servables[0], &*handles[0]);
160  EXPECT_EQ(&servables[0].member, &handles[0]->member);
161 }
162 
163 TEST(ServableHandleTest, Id) {
164  TestServable servables[2];
165  ServableHandle<TestServable> handles[2];
166 
167  const ServableId id = {"servable", 7};
168  handles[0] = test_util::WrapAsHandle(id, &servables[0]);
169  handles[1] = test_util::WrapAsHandle(id, &servables[1]);
170 
171  EXPECT_EQ(id, handles[0].id());
172  EXPECT_EQ(id, handles[1].id());
173 }
174 
175 TEST(ServableRequestTest, Specific) {
176  const auto request = ServableRequest::Specific("servable", 7);
177  EXPECT_EQ("servable", request.name);
178  EXPECT_EQ(7, *request.version);
179 }
180 
181 TEST(ServableRequestTest, Latest) {
182  const auto request = ServableRequest::Latest("servable");
183  EXPECT_EQ("servable", request.name);
184  EXPECT_FALSE(request.version);
185 }
186 
187 TEST(ServableRequestTest, FromId) {
188  const auto request = ServableRequest::FromId({"servable", 7});
189  EXPECT_EQ("servable", request.name);
190  EXPECT_EQ(7, *request.version);
191 }
192 
193 } // namespace
194 } // namespace serving
195 } // namespace tensorflow