TensorFlow Serving C++ API Documentation
tfrt_get_model_metadata_impl_test.cc
1 /* Copyright 2017 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/tfrt_get_model_metadata_impl.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "google/protobuf/wrappers.pb.h"
23 #include "google/protobuf/map.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/cc/saved_model/loader.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "xla/tsl/lib/core/status_test_util.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/protobuf/error_codes.pb.h"
31 #include "tensorflow/core/protobuf/meta_graph.pb.h"
32 #include "tsl/platform/casts.h"
33 #include "tensorflow_serving/apis/model.pb.h"
34 #include "tensorflow_serving/config/model_server_config.pb.h"
35 #include "tensorflow_serving/config/platform_config.pb.h"
36 #include "tensorflow_serving/core/aspired_version_policy.h"
37 #include "tensorflow_serving/core/availability_preserving_policy.h"
38 #include "tensorflow_serving/core/servable_handle.h"
39 #include "tensorflow_serving/model_servers/model_platform_types.h"
40 #include "tensorflow_serving/model_servers/platform_config_util.h"
41 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
42 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
43 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
44 #include "tensorflow_serving/test_util/test_util.h"
45 
46 namespace tensorflow {
47 namespace serving {
48 namespace {
49 
50 constexpr char kTestModelName[] = "test_model";
51 constexpr int kTestModelVersion = 123;
52 constexpr absl::string_view kSignatureDef = "signature_def";
53 
54 class TFRTGetModelMetadataImplTest : public ::testing::Test {
55  public:
56  static void SetUpTestSuite() {
57  tfrt_stub::SetGlobalRuntime(
58  tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4));
59 
60  const string saved_model_path = test_util::TensorflowTestSrcDirPath(
61  "cc/saved_model/testdata/half_plus_two");
62  TF_ASSERT_OK(
63  CreateServerCore(saved_model_path, true, &saved_model_server_core_));
64  }
65 
66  static void TearDownTestSuite() { saved_model_server_core_.reset(); }
67 
68  protected:
69  static Status CreateServerCore(const string& model_path,
70  bool saved_model_on_disk,
71  std::unique_ptr<ServerCore>* server_core) {
72  ModelServerConfig config;
73  auto model_config = config.mutable_model_config_list()->add_config();
74  model_config->set_name(kTestModelName);
75  model_config->set_base_path(model_path);
76  model_config->set_model_platform(kTensorFlowModelPlatform);
77 
78  // For ServerCore Options, we leave servable_state_monitor_creator
79  // unspecified so the default servable_state_monitor_creator will be used.
80  ServerCore::Options options;
81  options.model_server_config = config;
82  PlatformConfigMap platform_config_map;
83  ::google::protobuf::Any source_adapter_config;
84  TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
85  source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
86  (*(*platform_config_map
87  .mutable_platform_configs())[kTensorFlowModelPlatform]
88  .mutable_source_adapter_config()) = source_adapter_config;
89  options.platform_config_map = platform_config_map;
90  options.aspired_version_policy =
91  std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
92  // Reduce the number of initial load threads to be num_load_threads to avoid
93  // timing out in tests.
94  options.num_initial_load_threads = options.num_load_threads;
95  return ServerCore::Create(std::move(options), server_core);
96  }
97 
98  ServerCore* GetServerCore() { return saved_model_server_core_.get(); }
99 
100  private:
101  static std::unique_ptr<ServerCore> saved_model_server_core_;
102 };
103 
104 std::unique_ptr<ServerCore>
105  TFRTGetModelMetadataImplTest::saved_model_server_core_;
106 
107 SignatureDefMap GetSignatureDefMap(ServerCore* server_core,
108  const ModelSpec& model_spec) {
109  SignatureDefMap signature_def_map;
110  ServableHandle<Servable> servable;
111  TF_EXPECT_OK(server_core->GetServableHandle(model_spec, &servable));
112  auto& saved_model =
113  down_cast<TfrtSavedModelServable*>(servable.get())->saved_model();
114  for (const auto& signature : saved_model.GetMetaGraphDef().signature_def()) {
115  (*signature_def_map.mutable_signature_def())[signature.first] =
116  signature.second;
117  }
118  return signature_def_map;
119 }
120 
121 TEST_F(TFRTGetModelMetadataImplTest, EmptyOrInvalidMetadataFieldList) {
122  GetModelMetadataRequest request;
123  GetModelMetadataResponse response;
124 
125  // Empty metadata field list is invalid.
126  EXPECT_EQ(absl::StatusCode::kInvalidArgument,
127  TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
128  &response)
129  .code());
130  request.add_metadata_field("some_stuff");
131 
132  // Field enum is outside of valid range.
133  EXPECT_EQ(absl::StatusCode::kInvalidArgument,
134  TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
135  &response)
136  .code());
137 }
138 
139 TEST_F(TFRTGetModelMetadataImplTest, MissingOrEmptyModelSpec) {
140  GetModelMetadataRequest request;
141  GetModelMetadataResponse response;
142 
143  request.add_metadata_field(std::string(kSignatureDef));
144  EXPECT_EQ(absl::StatusCode::kInvalidArgument,
145  TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
146  &response)
147  .code());
148 
149  ModelSpec* model_spec = request.mutable_model_spec();
150  model_spec->clear_name();
151 
152  // Model name is not specified.
153  EXPECT_EQ(absl::StatusCode::kInvalidArgument,
154  TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
155  &response)
156  .code());
157 
158  // Model name is wrong, not found.
159  model_spec->set_name("test");
160  EXPECT_EQ(tensorflow::error::NOT_FOUND,
161  TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
162  &response)
163  .code());
164 }
165 
166 TEST_F(TFRTGetModelMetadataImplTest, ReturnsSignaturesForValidModel) {
167  GetModelMetadataRequest request;
168  GetModelMetadataResponse response;
169 
170  ModelSpec* model_spec = request.mutable_model_spec();
171  model_spec->set_name(kTestModelName);
172  model_spec->mutable_version()->set_value(kTestModelVersion);
173  request.add_metadata_field(std::string(kSignatureDef));
174 
175  TF_EXPECT_OK(TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(),
176  request, &response));
177  EXPECT_THAT(response.model_spec(),
178  test_util::EqualsProto(request.model_spec()));
179  EXPECT_EQ(response.metadata_size(), 1);
180  SignatureDefMap received_signature_def_map;
181  response.metadata().at(kSignatureDef).UnpackTo(&received_signature_def_map);
182 
183  SignatureDefMap expected_signature_def_map =
184  GetSignatureDefMap(GetServerCore(), request.model_spec());
185  EXPECT_THAT(response.model_spec(),
186  test_util::EqualsProto(request.model_spec()));
187 
188  EXPECT_EQ(expected_signature_def_map.signature_def().size(),
189  received_signature_def_map.signature_def().size());
190  EXPECT_THAT(
191  expected_signature_def_map.signature_def().at("regress_x_to_y"),
192  test_util::EqualsProto(
193  received_signature_def_map.signature_def().at("regress_x_to_y")));
194  EXPECT_THAT(
195  expected_signature_def_map.signature_def().at(
196  kDefaultServingSignatureDefKey),
197  test_util::EqualsProto(received_signature_def_map.signature_def().at(
198  kDefaultServingSignatureDefKey)));
199 }
200 
201 // Verifies that GetModelMetadataWithModelSpec() uses the model spec override
202 // rather than the one in the request.
203 TEST_F(TFRTGetModelMetadataImplTest, ModelSpecOverride) {
204  auto request = test_util::CreateProto<GetModelMetadataRequest>(
205  "model_spec {"
206  " name: \"test_model\""
207  "}");
208  request.add_metadata_field(std::string(kSignatureDef));
209  auto model_spec_override =
210  test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");
211 
212  GetModelMetadataResponse response;
213  EXPECT_NE(tensorflow::error::NOT_FOUND,
214  TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
215  &response)
216  .code());
217  EXPECT_EQ(tensorflow::error::NOT_FOUND,
218  TFRTGetModelMetadataImpl::GetModelMetadataWithModelSpec(
219  GetServerCore(), model_spec_override, request, &response)
220  .code());
221 }
222 
223 } // namespace
224 } // namespace serving
225 } // namespace tensorflow
static Status Create(Options options, std::unique_ptr< ServerCore > *core)
Definition: server_core.cc:231