16 #include "tensorflow_serving/core/source_adapter.h" 
   22 #include <gmock/gmock.h> 
   23 #include <gtest/gtest.h> 
   24 #include "tensorflow/core/lib/core/errors.h" 
   25 #include "tensorflow/core/lib/core/status.h" 
   26 #include "tensorflow/core/lib/strings/strcat.h" 
   27 #include "tensorflow/core/platform/env.h" 
   28 #include "tensorflow/core/platform/macros.h" 
   29 #include "tensorflow_serving/core/servable_id.h" 
   30 #include "tensorflow_serving/core/storage_path.h" 
   31 #include "tensorflow_serving/core/test_util/fake_storage_path_source_adapter.h" 
   32 #include "tensorflow_serving/core/test_util/mock_storage_path_target.h" 
   34 using ::testing::ElementsAre;
 
   36 using ::testing::IsEmpty;
 
   37 using ::testing::StrictMock;
 
   39 namespace tensorflow {
 
   44 class LimitedAdapter final : 
public SourceAdapter<StoragePath, StoragePath> {
 
   46   LimitedAdapter() = 
default;
 
   47   ~LimitedAdapter()
 override { Detach(); }
 
   50   std::vector<ServableData<StoragePath>> Adapt(
 
   51       const StringPiece servable_name,
 
   52       std::vector<ServableData<StoragePath>> versions)
 override {
 
   53     CHECK(versions.empty());
 
   58   TF_DISALLOW_COPY_AND_ASSIGN(LimitedAdapter);
 
   61 TEST(SourceAdapterTest, AdaptOneVersion) {
 
   62   test_util::FakeStoragePathSourceAdapter adapter(
"baz");
 
   63   ServableData<StoragePath> output =
 
   64       adapter.AdaptOneVersion(ServableData<StoragePath>({
"foo", 42}, 
"bar"));
 
   65   EXPECT_EQ(
"foo", output.id().name);
 
   66   EXPECT_EQ(42, output.id().version);
 
   67   EXPECT_EQ(
"bar/baz", output.DataOrDie());
 
   70 TEST(SourceAdapterTest, SetAspiredVersionsBlocksUntilTargetConnected) {
 
   71   LimitedAdapter adapter;
 
   72   std::unique_ptr<test_util::MockStoragePathTarget> target(
 
   73       new StrictMock<test_util::MockStoragePathTarget>);
 
   74   std::unique_ptr<Thread> connect_target(Env::Default()->StartThread(
 
   79         Env::Default()->SleepForMicroseconds(1 * 1000 * 1000 );
 
   80         ConnectSourceToTarget(&adapter, target.get());
 
   82   EXPECT_CALL(*target, SetAspiredVersions(Eq(
"foo"), IsEmpty()));
 
   83   adapter.SetAspiredVersions(
"foo", {});
 
   86 TEST(UnarySourceAdapterTest, Basic) {
 
   87   test_util::FakeStoragePathSourceAdapter adapter;
 
   88   std::unique_ptr<test_util::MockStoragePathTarget> target(
 
   89       new StrictMock<test_util::MockStoragePathTarget>);
 
   90   ConnectSourceToTarget(&adapter, target.get());
 
   96               ServableData<StoragePath>({
"foo", 0}, 
"mrop"),
 
   97               ServableData<StoragePath>(
 
   99                   errors::InvalidArgument(
 
  100                       "FakeStoragePathSourceAdapter Convert() dutifully " 
  101                       "failing on \"invalid\" data")),
 
  102               ServableData<StoragePath>({
"foo", 2}, errors::Unknown(
"d'oh")))));
 
  103   adapter.SetAspiredVersions(
 
  104       "foo", {ServableData<StoragePath>({
"foo", 0}, 
"mrop"),
 
  105               ServableData<StoragePath>({
"foo", 1}, 
"invalid"),
 
  106               ServableData<StoragePath>({
"foo", 2}, errors::Unknown(
"d'oh"))});
 
  109 TEST(ErrorInjectingSourceAdapterTest, Basic) {
 
  110   ErrorInjectingSourceAdapter<string, string> adapter(
 
  111       errors::Unknown(
"Injected error"));
 
  112   std::unique_ptr<test_util::MockStoragePathTarget> target(
 
  113       new StrictMock<test_util::MockStoragePathTarget>);
 
  114   ConnectSourceToTarget(&adapter, target.get());
 
  119           ElementsAre(ServableData<StoragePath>(
 
  120                           {
"foo", 0}, errors::Unknown(
"Injected error")),
 
  121                       ServableData<StoragePath>(
 
  122                           {
"foo", 1}, errors::Unknown(
"Original error")))));
 
  123   adapter.SetAspiredVersions(
 
  124       "foo", {ServableData<StoragePath>({
"foo", 0}, 
"mrop"),
 
  125               ServableData<StoragePath>({
"foo", 1},
 
  126                                         errors::Unknown(
"Original error"))});