TensorFlow Serving C++ API Documentation
source_adapter_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/source_adapter.h"
17 
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow_serving/core/servable_id.h"
30 #include "tensorflow_serving/core/storage_path.h"
31 #include "tensorflow_serving/core/test_util/fake_storage_path_source_adapter.h"
32 #include "tensorflow_serving/core/test_util/mock_storage_path_target.h"
33 
34 using ::testing::ElementsAre;
35 using ::testing::Eq;
36 using ::testing::IsEmpty;
37 using ::testing::StrictMock;
38 
39 namespace tensorflow {
40 namespace serving {
41 namespace {
42 
43 // A SourceAdapter that expects all aspired-versions requests to be empty.
44 class LimitedAdapter final : public SourceAdapter<StoragePath, StoragePath> {
45  public:
46  LimitedAdapter() = default;
47  ~LimitedAdapter() override { Detach(); }
48 
49  protected:
50  std::vector<ServableData<StoragePath>> Adapt(
51  const StringPiece servable_name,
52  std::vector<ServableData<StoragePath>> versions) override {
53  CHECK(versions.empty());
54  return {};
55  }
56 
57  private:
58  TF_DISALLOW_COPY_AND_ASSIGN(LimitedAdapter);
59 };
60 
61 TEST(SourceAdapterTest, AdaptOneVersion) {
62  test_util::FakeStoragePathSourceAdapter adapter("baz");
63  ServableData<StoragePath> output =
64  adapter.AdaptOneVersion(ServableData<StoragePath>({"foo", 42}, "bar"));
65  EXPECT_EQ("foo", output.id().name);
66  EXPECT_EQ(42, output.id().version);
67  EXPECT_EQ("bar/baz", output.DataOrDie());
68 }
69 
70 TEST(SourceAdapterTest, SetAspiredVersionsBlocksUntilTargetConnected) {
71  LimitedAdapter adapter;
72  std::unique_ptr<test_util::MockStoragePathTarget> target(
73  new StrictMock<test_util::MockStoragePathTarget>);
74  std::unique_ptr<Thread> connect_target(Env::Default()->StartThread(
75  {}, "ConnectTarget",
76  [&adapter, &target] {
77  // Sleep for a long time before connecting the target, to make it very
78  // likely that SetAspiredVersions() gets called first and has to block.
79  Env::Default()->SleepForMicroseconds(1 * 1000 * 1000 /* 1 second */);
80  ConnectSourceToTarget(&adapter, target.get());
81  }));
82  EXPECT_CALL(*target, SetAspiredVersions(Eq("foo"), IsEmpty()));
83  adapter.SetAspiredVersions("foo", {});
84 }
85 
86 TEST(UnarySourceAdapterTest, Basic) {
87  test_util::FakeStoragePathSourceAdapter adapter;
88  std::unique_ptr<test_util::MockStoragePathTarget> target(
89  new StrictMock<test_util::MockStoragePathTarget>);
90  ConnectSourceToTarget(&adapter, target.get());
91  EXPECT_CALL(
92  *target,
93  SetAspiredVersions(
94  Eq("foo"),
95  ElementsAre(
96  ServableData<StoragePath>({"foo", 0}, "mrop"),
97  ServableData<StoragePath>(
98  {"foo", 1},
99  errors::InvalidArgument(
100  "FakeStoragePathSourceAdapter Convert() dutifully "
101  "failing on \"invalid\" data")),
102  ServableData<StoragePath>({"foo", 2}, errors::Unknown("d'oh")))));
103  adapter.SetAspiredVersions(
104  "foo", {ServableData<StoragePath>({"foo", 0}, "mrop"),
105  ServableData<StoragePath>({"foo", 1}, "invalid"),
106  ServableData<StoragePath>({"foo", 2}, errors::Unknown("d'oh"))});
107 }
108 
109 TEST(ErrorInjectingSourceAdapterTest, Basic) {
110  ErrorInjectingSourceAdapter<string, string> adapter(
111  errors::Unknown("Injected error"));
112  std::unique_ptr<test_util::MockStoragePathTarget> target(
113  new StrictMock<test_util::MockStoragePathTarget>);
114  ConnectSourceToTarget(&adapter, target.get());
115  EXPECT_CALL(
116  *target,
117  SetAspiredVersions(
118  Eq("foo"),
119  ElementsAre(ServableData<StoragePath>(
120  {"foo", 0}, errors::Unknown("Injected error")),
121  ServableData<StoragePath>(
122  {"foo", 1}, errors::Unknown("Original error")))));
123  adapter.SetAspiredVersions(
124  "foo", {ServableData<StoragePath>({"foo", 0}, "mrop"),
125  ServableData<StoragePath>({"foo", 1},
126  errors::Unknown("Original error"))});
127 }
128 
129 } // namespace
130 } // namespace serving
131 } // namespace tensorflow