16 #include "tensorflow_serving/core/source_adapter.h"
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow_serving/core/servable_id.h"
30 #include "tensorflow_serving/core/storage_path.h"
31 #include "tensorflow_serving/core/test_util/fake_storage_path_source_adapter.h"
32 #include "tensorflow_serving/core/test_util/mock_storage_path_target.h"
34 using ::testing::ElementsAre;
36 using ::testing::IsEmpty;
37 using ::testing::StrictMock;
39 namespace tensorflow {
// NOTE(review): this copy of the file looks corrupted. Every line carries its
// original line number as a prefix (e.g. "44", "46"), and the numbering has
// gaps (45, 48-49, 54-57 and 59 are absent here). As a result, the access
// specifiers, the remainder of Adapt()'s body and the closing "};" are not
// visible, and the class will not compile as-is. Restore it from the upstream
// file rather than patching it by hand.
//
// Test-only adapter from StoragePath to StoragePath whose Adapt() CHECKs that
// every aspired-versions request it receives is empty. It is used by the
// SetAspiredVersionsBlocksUntilTargetConnected test.
44 class LimitedAdapter final :
public SourceAdapter<StoragePath, StoragePath> {
46 LimitedAdapter() =
default;
// Detach() is called on destruction. Presumably the SourceAdapter base class
// requires subclasses to do this before teardown; confirm in source_adapter.h.
47 ~LimitedAdapter()
override { Detach(); }
// Receives the versions for `servable_name`. Any non-empty list fails the
// CHECK. The value this override returns is not visible in this copy.
50 std::vector<ServableData<StoragePath>> Adapt(
51 const StringPiece servable_name,
52 std::vector<ServableData<StoragePath>> versions)
override {
53 CHECK(versions.empty());
// Copy construction and assignment are disabled.
58 TF_DISALLOW_COPY_AND_ASSIGN(LimitedAdapter);
// Checks AdaptOneVersion() on a single input. The servable id {"foo", 42} is
// expected to come through unchanged, and the data "bar" is expected to come
// out as "bar/baz". The expected value suggests that
// FakeStoragePathSourceAdapter joins the input path and its constructor
// argument ("baz") with '/'; confirm in
// test_util/fake_storage_path_source_adapter.h.
// NOTE(review): the embedded numbering jumps from 67 to 70, so the test's
// closing brace is not present in this copy of the file.
61 TEST(SourceAdapterTest, AdaptOneVersion) {
62 test_util::FakeStoragePathSourceAdapter adapter(
"baz");
63 ServableData<StoragePath> output =
64 adapter.AdaptOneVersion(ServableData<StoragePath>({
"foo", 42},
"bar"));
// The id (name and version) is expected to pass through untouched...
65 EXPECT_EQ(
"foo", output.id().name);
66 EXPECT_EQ(42, output.id().version);
// ...while the payload is rewritten by the fake adapter.
67 EXPECT_EQ(
"bar/baz", output.DataOrDie());
// As the test name states, SetAspiredVersions() on a source adapter is
// expected to block until a target is connected.
// - The main thread calls SetAspiredVersions("foo", {}) on a LimitedAdapter
//   right away.
// - Presumably the visible sleep and ConnectSourceToTarget() lines run inside
//   the thread started on the StartThread() line. That thread waits one second
//   before connecting the StrictMock target.
// The StrictMock target must still receive the (empty) "foo" request.
// NOTE(review): original lines 75-78 (the rest of the StartThread() arguments
// and the start of the thread body), 81 and 84 are missing from this copy.
// `Eq` is used, but its using-declaration (presumably original line 35, given
// the jump from 34 to 36 above) is not visible here.
70 TEST(SourceAdapterTest, SetAspiredVersionsBlocksUntilTargetConnected) {
71 LimitedAdapter adapter;
// StrictMock: any call not matched by an EXPECT_CALL fails the test.
72 std::unique_ptr<test_util::MockStoragePathTarget> target(
73 new StrictMock<test_util::MockStoragePathTarget>);
74 std::unique_ptr<Thread> connect_target(Env::Default()->StartThread(
// Sleep 1 * 1000 * 1000 us (one second), so that SetAspiredVersions() below
// very likely runs before any target is connected.
79 Env::Default()->SleepForMicroseconds(1 * 1000 * 1000 );
80 ConnectSourceToTarget(&adapter, target.get());
// The target must see SetAspiredVersions("foo", <empty list>).
82 EXPECT_CALL(*target, SetAspiredVersions(Eq(
"foo"), IsEmpty()));
83 adapter.SetAspiredVersions(
"foo", {});
// Sends one aspired-versions request with three versions through a
// default-constructed FakeStoragePathSourceAdapter. The test suite name
// indicates that this fake is a UnarySourceAdapter. The visible expectation
// is one result per version:
//   version 0, data "mrop"            -> data "mrop"
//   version 1, data "invalid"         -> InvalidArgument error from the fake's
//                                        Convert()
//   version 2, error Unknown("d'oh")  -> the same Unknown("d'oh") error
// NOTE(review): this copy is missing several lines:
//   - original lines 91-95, the head of the EXPECT_CALL that wraps the
//     expected list (presumably an ElementsAre matcher);
//   - original line 98, presumably the id of the second expected item;
//   - the closing brace.
86 TEST(UnarySourceAdapterTest, Basic) {
87 test_util::FakeStoragePathSourceAdapter adapter;
// StrictMock: any call not matched by an EXPECT_CALL fails the test.
88 std::unique_ptr<test_util::MockStoragePathTarget> target(
89 new StrictMock<test_util::MockStoragePathTarget>);
90 ConnectSourceToTarget(&adapter, target.get());
96 ServableData<StoragePath>({
"foo", 0},
"mrop"),
97 ServableData<StoragePath>(
99 errors::InvalidArgument(
100 "FakeStoragePathSourceAdapter Convert() dutifully "
101 "failing on \"invalid\" data")),
102 ServableData<StoragePath>({
"foo", 2}, errors::Unknown(
"d'oh")))));
// One request: valid data, data the fake rejects, and an already-failed
// version.
103 adapter.SetAspiredVersions(
104 "foo", {ServableData<StoragePath>({
"foo", 0},
"mrop"),
105 ServableData<StoragePath>({
"foo", 1},
"invalid"),
106 ServableData<StoragePath>({
"foo", 2}, errors::Unknown(
"d'oh"))});
// Checks an ErrorInjectingSourceAdapter<string, string> constructed with
// Unknown("Injected error"). The visible expectation is:
//   - a version carrying data ("mrop") comes out as the injected error;
//   - a version that already carries an error keeps its original error.
// NOTE(review): the adapter is <string, string>, but the target is a
// MockStoragePathTarget. This relies on StoragePath being string (presumably a
// typedef in storage_path.h; confirm). Original lines 115-118 (the head of the
// EXPECT_CALL) are missing from this copy, and the end of the test is not
// visible.
109 TEST(ErrorInjectingSourceAdapterTest, Basic) {
110 ErrorInjectingSourceAdapter<string, string> adapter(
111 errors::Unknown(
"Injected error"));
// StrictMock: any call not matched by an EXPECT_CALL fails the test.
112 std::unique_ptr<test_util::MockStoragePathTarget> target(
113 new StrictMock<test_util::MockStoragePathTarget>);
114 ConnectSourceToTarget(&adapter, target.get());
119 ElementsAre(ServableData<StoragePath>(
120 {
"foo", 0}, errors::Unknown(
"Injected error")),
121 ServableData<StoragePath>(
122 {
"foo", 1}, errors::Unknown(
"Original error")))));
// Version 0 carries data; version 1 already carries an error.
123 adapter.SetAspiredVersions(
124 "foo", {ServableData<StoragePath>({
"foo", 0},
"mrop"),
125 ServableData<StoragePath>({
"foo", 1},
126 errors::Unknown(
"Original error"))});