// TensorFlow Serving C++ API Documentation
// get_model_metadata_impl_test.cc
/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/get_model_metadata_impl.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "google/protobuf/wrappers.pb.h"
23 #include "google/protobuf/map.h"
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include "tensorflow/cc/saved_model/loader.h"
27 #include "tensorflow/cc/saved_model/signature_constants.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/protobuf/error_codes.pb.h"
31 #include "tensorflow/core/protobuf/meta_graph.pb.h"
32 #include "tensorflow_serving/apis/model.pb.h"
33 #include "tensorflow_serving/config/model_server_config.pb.h"
34 #include "tensorflow_serving/config/platform_config.pb.h"
35 #include "tensorflow_serving/core/aspired_version_policy.h"
36 #include "tensorflow_serving/core/availability_preserving_policy.h"
37 #include "tensorflow_serving/core/servable_handle.h"
38 #include "tensorflow_serving/model_servers/model_platform_types.h"
39 #include "tensorflow_serving/model_servers/platform_config_util.h"
40 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
41 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
42 #include "tensorflow_serving/test_util/test_util.h"
43 #include "tensorflow_serving/util/oss_or_google.h"
44 
45 namespace tensorflow {
46 namespace serving {
47 namespace {
48 
// Name under which the test model is registered in the ServerCore config and
// requested by the tests.
constexpr char kTestModelName[]{"test_model"};
// Model version that tests pin explicitly in their ModelSpec.
constexpr int kTestModelVersion{123};
51 
52 class GetModelMetadataImplTest : public ::testing::TestWithParam<bool> {
53  public:
54  static void SetUpTestSuite() {
55  if (!IsTensorflowServingOSS()) {
56  const string session_bundle_path = test_util::TestSrcDirPath(
57  "/servables/tensorflow/google/testdata/half_plus_two");
58  TF_ASSERT_OK(CreateServerCore(session_bundle_path, false, &server_core_));
59  }
60 
61  const string saved_model_path = test_util::TensorflowTestSrcDirPath(
62  "cc/saved_model/testdata/half_plus_two");
63  TF_ASSERT_OK(
64  CreateServerCore(saved_model_path, true, &saved_model_server_core_));
65  }
66 
67  static void TearDownTestSuite() {
68  server_core_.reset();
69  saved_model_server_core_.reset();
70  }
71 
72  protected:
73  static Status CreateServerCore(const string& model_path,
74  bool saved_model_on_disk,
75  std::unique_ptr<ServerCore>* server_core) {
76  ModelServerConfig config;
77  auto model_config = config.mutable_model_config_list()->add_config();
78  model_config->set_name(kTestModelName);
79  model_config->set_base_path(model_path);
80  model_config->set_model_platform(kTensorFlowModelPlatform);
81 
82  // For ServerCore Options, we leave servable_state_monitor_creator
83  // unspecified so the default servable_state_monitor_creator will be used.
84  ServerCore::Options options;
85  options.model_server_config = config;
86  options.platform_config_map =
87  CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
88  options.aspired_version_policy =
89  std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
90  // Reduce the number of initial load threads to be num_load_threads to avoid
91  // timing out in tests.
92  options.num_initial_load_threads = options.num_load_threads;
93  return ServerCore::Create(std::move(options), server_core);
94  }
95 
96  ServerCore* GetServerCore() {
97  if (GetParam()) {
98  return saved_model_server_core_.get();
99  }
100  return server_core_.get();
101  }
102 
103  private:
104  static std::unique_ptr<ServerCore> server_core_;
105  static std::unique_ptr<ServerCore> saved_model_server_core_;
106 };
107 
108 std::unique_ptr<ServerCore> GetModelMetadataImplTest::server_core_;
109 std::unique_ptr<ServerCore> GetModelMetadataImplTest::saved_model_server_core_;
110 
111 SignatureDefMap GetSignatureDefMap(ServerCore* server_core,
112  const ModelSpec& model_spec) {
113  SignatureDefMap signature_def_map;
114  ServableHandle<SavedModelBundle> bundle;
115  TF_EXPECT_OK(server_core->GetServableHandle(model_spec, &bundle));
116  for (const auto& signature : bundle->meta_graph_def.signature_def()) {
117  (*signature_def_map.mutable_signature_def())[signature.first] =
118  signature.second;
119  }
120  return signature_def_map;
121 }
122 
123 TEST_P(GetModelMetadataImplTest, EmptyOrInvalidMetadataFieldList) {
124  GetModelMetadataRequest request;
125  GetModelMetadataResponse response;
126 
127  // Empty metadata field list is invalid.
128  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
129  GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
130  &response)
131  .code());
132  request.add_metadata_field("some_stuff");
133 
134  // Field enum is outside of valid range.
135  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
136  GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
137  &response)
138  .code());
139 }
140 
141 TEST_P(GetModelMetadataImplTest, MissingOrEmptyModelSpec) {
142  GetModelMetadataRequest request;
143  GetModelMetadataResponse response;
144 
145  request.add_metadata_field(GetModelMetadataImpl::kSignatureDef);
146  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
147  GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
148  &response)
149  .code());
150 
151  ModelSpec* model_spec = request.mutable_model_spec();
152  model_spec->clear_name();
153 
154  // Model name is not specified.
155  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
156  GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
157  &response)
158  .code());
159 
160  // Model name is wrong, not found.
161  model_spec->set_name("test");
162  EXPECT_EQ(tensorflow::error::NOT_FOUND,
163  GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
164  &response)
165  .code());
166 }
167 
168 TEST_P(GetModelMetadataImplTest, ReturnsSignaturesForValidModel) {
169  GetModelMetadataRequest request;
170  GetModelMetadataResponse response;
171 
172  ModelSpec* model_spec = request.mutable_model_spec();
173  model_spec->set_name(kTestModelName);
174  model_spec->mutable_version()->set_value(kTestModelVersion);
175  request.add_metadata_field(GetModelMetadataImpl::kSignatureDef);
176 
177  TF_EXPECT_OK(GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
178  &response));
179  EXPECT_THAT(response.model_spec(),
180  test_util::EqualsProto(request.model_spec()));
181  EXPECT_EQ(response.metadata_size(), 1);
182  SignatureDefMap received_signature_def_map;
183  response.metadata()
184  .at(GetModelMetadataImpl::kSignatureDef)
185  .UnpackTo(&received_signature_def_map);
186 
187  SignatureDefMap expected_signature_def_map =
188  GetSignatureDefMap(GetServerCore(), request.model_spec());
189  EXPECT_THAT(response.model_spec(),
190  test_util::EqualsProto(request.model_spec()));
191 
192  EXPECT_EQ(expected_signature_def_map.signature_def().size(),
193  received_signature_def_map.signature_def().size());
194  if (GetParam()) {
195  EXPECT_THAT(
196  expected_signature_def_map.signature_def().at("regress_x_to_y"),
197  test_util::EqualsProto(
198  received_signature_def_map.signature_def().at("regress_x_to_y")));
199  } else {
200  EXPECT_THAT(expected_signature_def_map.signature_def().at("regress"),
201  test_util::EqualsProto(
202  received_signature_def_map.signature_def().at("regress")));
203  }
204  EXPECT_THAT(
205  expected_signature_def_map.signature_def().at(
206  kDefaultServingSignatureDefKey),
207  test_util::EqualsProto(received_signature_def_map.signature_def().at(
208  kDefaultServingSignatureDefKey)));
209 }
210 
211 // Verifies that GetModelMetadataWithModelSpec() uses the model spec override
212 // rather than the one in the request.
213 TEST_P(GetModelMetadataImplTest, ModelSpecOverride) {
214  auto request = test_util::CreateProto<GetModelMetadataRequest>(
215  "model_spec {"
216  " name: \"test_model\""
217  "}");
218  request.add_metadata_field(GetModelMetadataImpl::kSignatureDef);
219  auto model_spec_override =
220  test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");
221 
222  GetModelMetadataResponse response;
223  EXPECT_NE(tensorflow::error::NOT_FOUND,
224  GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
225  &response)
226  .code());
227  EXPECT_EQ(tensorflow::error::NOT_FOUND,
228  GetModelMetadataImpl::GetModelMetadataWithModelSpec(
229  GetServerCore(), model_spec_override, request, &response)
230  .code());
231 }
232 
233 // Test all ClassifierTest test cases with both SessionBundle and SavedModel.
234 INSTANTIATE_TEST_CASE_P(UseSavedModel, GetModelMetadataImplTest,
235  IsTensorflowServingOSS() ? ::testing::Values(true)
236  : ::testing::Bool());
237 
238 } // namespace
239 } // namespace serving
240 } // namespace tensorflow
// Doxygen cross-reference for ServerCore::Create(), used above:
//   static Status Create(Options options, std::unique_ptr<ServerCore>* core)
//   Definition: server_core.cc:231