16 #include "tensorflow_serving/servables/tensorflow/get_model_metadata_impl.h"
22 #include "google/protobuf/wrappers.pb.h"
23 #include "google/protobuf/map.h"
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include "tensorflow/cc/saved_model/loader.h"
27 #include "tensorflow/cc/saved_model/signature_constants.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/protobuf/error_codes.pb.h"
31 #include "tensorflow/core/protobuf/meta_graph.pb.h"
32 #include "tensorflow_serving/apis/model.pb.h"
33 #include "tensorflow_serving/config/model_server_config.pb.h"
34 #include "tensorflow_serving/config/platform_config.pb.h"
35 #include "tensorflow_serving/core/aspired_version_policy.h"
36 #include "tensorflow_serving/core/availability_preserving_policy.h"
37 #include "tensorflow_serving/core/servable_handle.h"
38 #include "tensorflow_serving/model_servers/model_platform_types.h"
39 #include "tensorflow_serving/model_servers/platform_config_util.h"
40 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
41 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
42 #include "tensorflow_serving/test_util/test_util.h"
43 #include "tensorflow_serving/util/oss_or_google.h"
45 namespace tensorflow {
// Name and version under which the half_plus_two test model is registered
// with ServerCore (CreateServerCore) and requested by the tests in this file.
constexpr char kTestModelName[]{"test_model"};
constexpr int kTestModelVersion{123};
// Test fixture for GetModelMetadataImpl, parameterized on a bool that
// selects which shared ServerCore GetServerCore() hands to a test.
// ServerCore instances are created once per suite in SetUpTestSuite() and
// shared by every test through the static members declared at the bottom.
52 class GetModelMetadataImplTest :
public ::testing::TestWithParam<bool> {
// Loads the half_plus_two test model once for the whole suite.
// server_core_ (session-bundle testdata, saved_model_on_disk=false) is only
// created outside OSS builds. saved_model_server_core_ (SavedModel testdata,
// saved_model_on_disk=true) appears to be created unconditionally.
54 static void SetUpTestSuite() {
55 if (!IsTensorflowServingOSS()) {
56 const string session_bundle_path = test_util::TestSrcDirPath(
57 "/servables/tensorflow/google/testdata/half_plus_two");
58 TF_ASSERT_OK(CreateServerCore(session_bundle_path,
false, &server_core_));
61 const string saved_model_path = test_util::TensorflowTestSrcDirPath(
62 "cc/saved_model/testdata/half_plus_two");
64 CreateServerCore(saved_model_path,
true, &saved_model_server_core_));
// Releases the shared ServerCore instances at the end of the suite.
67 static void TearDownTestSuite() {
69 saved_model_server_core_.reset();
// Builds a ServerCore that serves a single TensorFlow-platform model named
// kTestModelName from `model_path`.
//   model_path: base path of the model's version directories.
//   saved_model_on_disk: NOTE(review): not read by any statement visible in
//     this function; confirm whether the parameter is still needed.
//   server_core: output; presumably receives the created ServerCore.
// Presumably returns the status of ServerCore creation — TODO confirm.
73 static Status CreateServerCore(
const string& model_path,
74 bool saved_model_on_disk,
75 std::unique_ptr<ServerCore>* server_core) {
76 ModelServerConfig config;
77 auto model_config = config.mutable_model_config_list()->add_config();
78 model_config->set_name(kTestModelName);
79 model_config->set_base_path(model_path);
80 model_config->set_model_platform(kTensorFlowModelPlatform);
// Server options: the single-model config above, the TensorFlow platform
// config built from a default SessionBundleConfig, and an
// AvailabilityPreservingPolicy for version transitions.
84 ServerCore::Options options;
85 options.model_server_config = config;
86 options.platform_config_map =
87 CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
// NOTE(review): std::make_unique<AvailabilityPreservingPolicy>() would avoid
// the naked `new` here.
88 options.aspired_version_policy =
89 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
// Use the regular load-thread count for the initial load too; presumably
// to keep test start-up lightweight — confirm.
92 options.num_initial_load_threads = options.num_load_threads;
// Returns the ServerCore for the current test: either the SavedModel-backed
// core or the session-bundle core. The selection is presumably driven by
// GetParam() (true -> SavedModel) — confirm.
96 ServerCore* GetServerCore() {
98 return saved_model_server_core_.get();
100 return server_core_.get();
// Suite-wide ServerCore instances, created in SetUpTestSuite().
104 static std::unique_ptr<ServerCore> server_core_;
105 static std::unique_ptr<ServerCore> saved_model_server_core_;
108 std::unique_ptr<ServerCore> GetModelMetadataImplTest::server_core_;
109 std::unique_ptr<ServerCore> GetModelMetadataImplTest::saved_model_server_core_;
// Reads the SignatureDefs straight from the loaded servable identified by
// `model_spec`. Used as the expected value when checking what
// GetModelMetadataImpl returns.
//   server_core: ServerCore holding the loaded model.
//   model_spec: identifies the servable to look up.
// Returns a SignatureDefMap keyed by signature name.
111 SignatureDefMap GetSignatureDefMap(ServerCore* server_core,
112 const ModelSpec& model_spec) {
113 SignatureDefMap signature_def_map;
114 ServableHandle<SavedModelBundle> bundle;
// NOTE(review): TF_EXPECT_OK is a non-fatal check, so if the lookup fails
// execution continues and `bundle->` below is dereferenced anyway. Consider
// returning early when the status is not OK.
115 TF_EXPECT_OK(server_core->GetServableHandle(model_spec, &bundle));
// Copy every signature from the bundle's MetaGraphDef into the result map.
116 for (
const auto& signature : bundle->meta_graph_def.signature_def()) {
117 (*signature_def_map.mutable_signature_def())[signature.first] =
120 return signature_def_map;
// GetModelMetadata must reject requests whose metadata_field list is empty or
// names a field it does not support.
// NOTE(review): static_cast<absl::StatusCode>(absl::StatusCode::...) is a
// no-op cast; absl::StatusCode::kInvalidArgument can be used directly.
123 TEST_P(GetModelMetadataImplTest, EmptyOrInvalidMetadataFieldList) {
124 GetModelMetadataRequest request;
125 GetModelMetadataResponse response;
// No metadata_field entries at all -> InvalidArgument.
128 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
129 GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// An unrecognized metadata field name -> InvalidArgument as well.
132 request.add_metadata_field(
"some_stuff");
135 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
136 GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// GetModelMetadata must validate the request's model_spec. A missing spec or
// an empty name is InvalidArgument. A name that no loaded model uses is
// NOT_FOUND.
// NOTE(review): this test compares against both absl::StatusCode values and
// tensorflow::error::NOT_FOUND; using one status-code style would read more
// consistently.
141 TEST_P(GetModelMetadataImplTest, MissingOrEmptyModelSpec) {
142 GetModelMetadataRequest request;
143 GetModelMetadataResponse response;
145 request.add_metadata_field(GetModelMetadataImpl::kSignatureDef);
// Valid metadata field, but no model_spec set -> InvalidArgument.
146 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
147 GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// model_spec present but with an empty name -> InvalidArgument.
151 ModelSpec* model_spec = request.mutable_model_spec();
152 model_spec->clear_name();
155 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
156 GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// "test" is not the configured model name (kTestModelName) -> NOT_FOUND.
161 model_spec->set_name(
"test");
162 EXPECT_EQ(tensorflow::error::NOT_FOUND,
163 GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// For a loaded model, GetModelMetadata with the SignatureDef field must return
// the same signatures that the servable's MetaGraphDef contains.
168 TEST_P(GetModelMetadataImplTest, ReturnsSignaturesForValidModel) {
169 GetModelMetadataRequest request;
170 GetModelMetadataResponse response;
// Request SignatureDef metadata for the test model at its fixed version.
172 ModelSpec* model_spec = request.mutable_model_spec();
173 model_spec->set_name(kTestModelName);
174 model_spec->mutable_version()->set_value(kTestModelVersion);
175 request.add_metadata_field(GetModelMetadataImpl::kSignatureDef);
177 TF_EXPECT_OK(GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// The response echoes the requested model_spec and carries exactly one
// metadata entry.
179 EXPECT_THAT(response.model_spec(),
180 test_util::EqualsProto(request.model_spec()));
181 EXPECT_EQ(response.metadata_size(), 1);
// Unpack the returned SignatureDef metadata.
// NOTE(review): if the kSignatureDef entry is absent, .at() likely aborts
// instead of producing a clean gtest failure — confirm; consider guarding it
// with ASSERT_EQ on metadata_size().
182 SignatureDefMap received_signature_def_map;
184 .at(GetModelMetadataImpl::kSignatureDef)
185 .UnpackTo(&received_signature_def_map);
// Expected value: signatures read directly from the loaded servable.
187 SignatureDefMap expected_signature_def_map =
188 GetSignatureDefMap(GetServerCore(), request.model_spec());
// NOTE(review): duplicates the model_spec check a few lines above.
189 EXPECT_THAT(response.model_spec(),
190 test_util::EqualsProto(request.model_spec()));
// Same number of signatures, and the "regress_x_to_y", "regress" and default
// serving signatures match one by one.
192 EXPECT_EQ(expected_signature_def_map.signature_def().size(),
193 received_signature_def_map.signature_def().size());
196 expected_signature_def_map.signature_def().at(
"regress_x_to_y"),
197 test_util::EqualsProto(
198 received_signature_def_map.signature_def().at(
"regress_x_to_y")));
200 EXPECT_THAT(expected_signature_def_map.signature_def().at(
"regress"),
201 test_util::EqualsProto(
202 received_signature_def_map.signature_def().at(
"regress")));
205 expected_signature_def_map.signature_def().at(
206 kDefaultServingSignatureDefKey),
207 test_util::EqualsProto(received_signature_def_map.signature_def().at(
208 kDefaultServingSignatureDefKey)));
// GetModelMetadataWithModelSpec must look up the model named by the explicit
// override spec rather than the request's own model_spec.
213 TEST_P(GetModelMetadataImplTest, ModelSpecOverride) {
214 auto request = test_util::CreateProto<GetModelMetadataRequest>(
216 " name: \"test_model\""
218 request.add_metadata_field(GetModelMetadataImpl::kSignatureDef);
// Override that names a model which is not loaded.
219 auto model_spec_override =
220 test_util::CreateProto<ModelSpec>(
"name: \"nonexistent_model\"");
222 GetModelMetadataResponse response;
// With the request's own model_spec ("test_model") the model is found.
223 EXPECT_NE(tensorflow::error::NOT_FOUND,
224 GetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// With the override the lookup targets "nonexistent_model" -> NOT_FOUND.
// This shows that the override takes precedence over request.model_spec().
227 EXPECT_EQ(tensorflow::error::NOT_FOUND,
228 GetModelMetadataImpl::GetModelMetadataWithModelSpec(
229 GetServerCore(), model_spec_override, request, &response)
234 INSTANTIATE_TEST_CASE_P(UseSavedModel, GetModelMetadataImplTest,
235 IsTensorflowServingOSS() ? ::testing::Values(
true)
236 : ::testing::Bool());
static Status Create(Options options, std::unique_ptr< ServerCore > *core)