16 #include "tensorflow_serving/servables/tensorflow/tfrt_get_model_metadata_impl.h"
22 #include "google/protobuf/wrappers.pb.h"
23 #include "google/protobuf/map.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/cc/saved_model/loader.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "xla/tsl/lib/core/status_test_util.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/protobuf/error_codes.pb.h"
31 #include "tensorflow/core/protobuf/meta_graph.pb.h"
32 #include "tsl/platform/casts.h"
33 #include "tensorflow_serving/apis/model.pb.h"
34 #include "tensorflow_serving/config/model_server_config.pb.h"
35 #include "tensorflow_serving/config/platform_config.pb.h"
36 #include "tensorflow_serving/core/aspired_version_policy.h"
37 #include "tensorflow_serving/core/availability_preserving_policy.h"
38 #include "tensorflow_serving/core/servable_handle.h"
39 #include "tensorflow_serving/model_servers/model_platform_types.h"
40 #include "tensorflow_serving/model_servers/platform_config_util.h"
41 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
42 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
43 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
44 #include "tensorflow_serving/test_util/test_util.h"
46 namespace tensorflow {
// Name under which the half_plus_two test model is registered with ServerCore;
// the tests use it to address the loaded model.
constexpr char kTestModelName[]{"test_model"};
// Version number requested by the test that asks for a specific model version.
constexpr int kTestModelVersion{123};
52 constexpr absl::string_view kSignatureDef =
"signature_def";
// Test fixture for TFRTGetModelMetadataImpl. Every test in the suite shares one
// ServerCore. It is held in the static saved_model_server_core_, set up once in
// SetUpTestSuite() and released in TearDownTestSuite().
// NOTE(review): access specifiers and the closing "};" of this class are not
// visible in this view; lines appear to have been elided.
54 class TFRTGetModelMetadataImplTest :
public ::testing::Test {
// Runs once before the first test of the suite. Installs a global TFRT runtime
// and loads the half_plus_two SavedModel from the TensorFlow test data tree.
56 static void SetUpTestSuite() {
// NOTE(review): the argument 4 is presumably the runtime's worker thread
// count; confirm against tfrt_stub::Runtime::Create.
57 tfrt_stub::SetGlobalRuntime(
58 tfrt_stub::Runtime::Create(4));
60 const string saved_model_path = test_util::TensorflowTestSrcDirPath(
61 "cc/saved_model/testdata/half_plus_two");
// NOTE(review): this call closes one more ')' than it opens, so a wrapper such
// as TF_ASSERT_OK( appears to be elided from this view. The returned Status
// should be checked; confirm against the full file.
63 CreateServerCore(saved_model_path,
true, &saved_model_server_core_));
// Runs once after the last test of the suite and destroys the shared ServerCore.
66 static void TearDownTestSuite() { saved_model_server_core_.reset(); }
// Assembles ServerCore::Options for a single model. The model is named
// kTestModelName, located at `model_path`, and served on kTensorFlowModelPlatform
// through the TFRT SavedModel source adapter with an
// AvailabilityPreservingPolicy.
//
// model_path:          base path of the model to serve.
// saved_model_on_disk: NOTE(review): not read by any line visible here.
// server_core:         presumably receives the created ServerCore (not used by
//                      any visible line).
// NOTE(review): the step that creates the ServerCore from `options` and the
// return statement are not visible in this view.
69 static Status CreateServerCore(
const string& model_path,
70 bool saved_model_on_disk,
71 std::unique_ptr<ServerCore>* server_core) {
// A ModelServerConfig with one model entry for the test model.
72 ModelServerConfig config;
73 auto model_config = config.mutable_model_config_list()->add_config();
74 model_config->set_name(kTestModelName);
75 model_config->set_base_path(model_path);
76 model_config->set_model_platform(kTensorFlowModelPlatform);
80 ServerCore::Options options;
81 options.model_server_config = config;
// Map kTensorFlowModelPlatform to a source adapter config. The config is an Any
// wrapping a default-constructed TfrtSavedModelSourceAdapterConfig.
82 PlatformConfigMap platform_config_map;
83 ::google::protobuf::Any source_adapter_config;
84 TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
85 source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
86 (*(*platform_config_map
87 .mutable_platform_configs())[kTensorFlowModelPlatform]
88 .mutable_source_adapter_config()) = source_adapter_config;
89 options.platform_config_map = platform_config_map;
// NOTE(review): std::make_unique<AvailabilityPreservingPolicy>() would express
// this without the naked new.
90 options.aspired_version_policy =
91 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
// Use the same thread count for the initial load as for later loads.
94 options.num_initial_load_threads = options.num_load_threads;
// Non-owning accessor for the suite-wide ServerCore.
98 ServerCore* GetServerCore() {
return saved_model_server_core_.get(); }
// Owns the ServerCore shared by all tests in the suite.
101 static std::unique_ptr<ServerCore> saved_model_server_core_;
104 std::unique_ptr<ServerCore>
105 TFRTGetModelMetadataImplTest::saved_model_server_core_;
// Returns a SignatureDefMap holding every signature_def of the TFRT SavedModel
// servable that `server_core` resolves for `model_spec`. The tests below use it
// as the expected value.
// NOTE(review): TF_EXPECT_OK records a failure but does not stop execution, so a
// failed lookup still reaches the down_cast below.
107 SignatureDefMap GetSignatureDefMap(ServerCore* server_core,
108 const ModelSpec& model_spec) {
109 SignatureDefMap signature_def_map;
110 ServableHandle<Servable> servable;
111 TF_EXPECT_OK(server_core->GetServableHandle(model_spec, &servable));
// NOTE(review): the declaration that binds this result (used below as
// `saved_model`) is not visible in this view. down_cast assumes the servable
// really is a TfrtSavedModelServable.
113 down_cast<TfrtSavedModelServable*>(servable.get())->saved_model();
// Copy each (key, SignatureDef) entry of the model's MetaGraphDef into the map.
114 for (
const auto& signature : saved_model.GetMetaGraphDef().signature_def()) {
115 (*signature_def_map.mutable_signature_def())[signature.first] =
118 return signature_def_map;
// GetModelMetadata is expected to return kInvalidArgument in two cases:
// the request lists no metadata field, or its only metadata field is an
// unexpected name.
// NOTE(review): the remaining arguments and the status-code accessor of each
// call are not visible in this view.
121 TEST_F(TFRTGetModelMetadataImplTest, EmptyOrInvalidMetadataFieldList) {
122 GetModelMetadataRequest request;
123 GetModelMetadataResponse response;
// No metadata_field has been added to the request yet.
126 EXPECT_EQ(absl::StatusCode::kInvalidArgument,
127 TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// The request's only metadata field is now "some_stuff".
130 request.add_metadata_field(
"some_stuff");
133 EXPECT_EQ(absl::StatusCode::kInvalidArgument,
134 TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// Expected results from GetModelMetadata for a signature_def request:
// - kInvalidArgument when the request has no model_spec;
// - kInvalidArgument when the model name is empty;
// - NOT_FOUND when the name differs from the only configured model
//   (kTestModelName).
139 TEST_F(TFRTGetModelMetadataImplTest, MissingOrEmptyModelSpec) {
140 GetModelMetadataRequest request;
141 GetModelMetadataResponse response;
143 request.add_metadata_field(std::string(kSignatureDef));
// No model_spec has been set on the request.
144 EXPECT_EQ(absl::StatusCode::kInvalidArgument,
145 TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// A model_spec whose name has been cleared.
149 ModelSpec* model_spec = request.mutable_model_spec();
150 model_spec->clear_name();
153 EXPECT_EQ(absl::StatusCode::kInvalidArgument,
154 TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// A name ("test") that does not match the configured model.
// NOTE(review): this check uses tensorflow::error::NOT_FOUND while the others
// use absl::StatusCode; absl::StatusCode::kNotFound would be consistent.
159 model_spec->set_name(
"test");
160 EXPECT_EQ(tensorflow::error::NOT_FOUND,
161 TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// Requests signature_def metadata for kTestModelName at kTestModelVersion.
// Expects the call to succeed, echo the request's model_spec, and return
// exactly one metadata entry. That entry's SignatureDefMap must agree with
// GetSignatureDefMap() in size and in the "regress_x_to_y" and default serving
// signatures.
166 TEST_F(TFRTGetModelMetadataImplTest, ReturnsSignaturesForValidModel) {
167 GetModelMetadataRequest request;
168 GetModelMetadataResponse response;
170 ModelSpec* model_spec = request.mutable_model_spec();
171 model_spec->set_name(kTestModelName);
172 model_spec->mutable_version()->set_value(kTestModelVersion);
173 request.add_metadata_field(std::string(kSignatureDef));
175 TF_EXPECT_OK(TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(),
176 request, &response));
177 EXPECT_THAT(response.model_spec(),
178 test_util::EqualsProto(request.model_spec()));
179 EXPECT_EQ(response.metadata_size(), 1);
180 SignatureDefMap received_signature_def_map;
// NOTE(review): the bool returned by UnpackTo is not checked; consider
// asserting on it so a type mismatch is reported directly.
181 response.metadata().at(kSignatureDef).UnpackTo(&received_signature_def_map);
// Expected signatures, read straight from the loaded servable.
183 SignatureDefMap expected_signature_def_map =
184 GetSignatureDefMap(GetServerCore(), request.model_spec());
// NOTE(review): repeats the model_spec check made above.
185 EXPECT_THAT(response.model_spec(),
186 test_util::EqualsProto(request.model_spec()));
188 EXPECT_EQ(expected_signature_def_map.signature_def().size(),
189 received_signature_def_map.signature_def().size());
// NOTE(review): the opening of this assertion is not visible in this view.
191 expected_signature_def_map.signature_def().at(
"regress_x_to_y"),
192 test_util::EqualsProto(
193 received_signature_def_map.signature_def().at(
"regress_x_to_y")));
// NOTE(review): the opening of this assertion is not visible in this view.
195 expected_signature_def_map.signature_def().at(
196 kDefaultServingSignatureDefKey),
197 test_util::EqualsProto(received_signature_def_map.signature_def().at(
198 kDefaultServingSignatureDefKey)));
// The model_spec passed to GetModelMetadataWithModelSpec is expected to take
// precedence over the request's own model_spec. The request alone must not
// yield NOT_FOUND, but the same request with an override naming
// "nonexistent_model" must yield NOT_FOUND.
// NOTE(review): part of the request's text proto and the end of each call are
// not visible in this view.
203 TEST_F(TFRTGetModelMetadataImplTest, ModelSpecOverride) {
204 auto request = test_util::CreateProto<GetModelMetadataRequest>(
206 " name: \"test_model\""
208 request.add_metadata_field(std::string(kSignatureDef));
209 auto model_spec_override =
210 test_util::CreateProto<ModelSpec>(
"name: \"nonexistent_model\"");
212 GetModelMetadataResponse response;
// Without an override the visible request names "test_model".
213 EXPECT_NE(tensorflow::error::NOT_FOUND,
214 TFRTGetModelMetadataImpl::GetModelMetadata(GetServerCore(), request,
// With the override the lookup targets "nonexistent_model".
217 EXPECT_EQ(tensorflow::error::NOT_FOUND,
218 TFRTGetModelMetadataImpl::GetModelMetadataWithModelSpec(
219 GetServerCore(), model_spec_override, request, &response)
// NOTE(review): detached declaration with no enclosing class and no terminating
// ';'. Its shape (ServerCore Options in, std::unique_ptr<ServerCore>* out)
// suggests it is the ServerCore::Create signature; confirm whether it belongs
// in this file.
static Status Create(Options options, std::unique_ptr< ServerCore > *core)