#include "tensorflow_serving/model_servers/get_model_status_impl.h"

#include <cstdint>
#include <memory>
#include <set>
#include <utility>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow_serving/apis/model.pb.h"
#include "tensorflow_serving/apis/status.pb.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/model_servers/platform_config_util.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/test_util/test_util.h"
37 namespace tensorflow {
// Test model location. It is passed through test_util::TestSrcDirPath() by
// the fixture, so it is relative to the test source directory.
constexpr char kTestModelBasePath[] =
    "/servables/tensorflow/testdata/saved_model_half_plus_two_2_versions";

// Name the fixture's ServerCore is configured to serve.
constexpr char kTestModelName[] = "saved_model_half_plus_two_2_versions";

// A name the fixture's ServerCore is not configured to serve.
constexpr char kNonexistentModelName[] = "nonexistent_model";

// The two versions pinned by the fixture's specific-version policy.
constexpr int kTestModelVersion1{123};
constexpr int kTestModelVersion2{124};

// A version number outside the pinned policy.
constexpr int kNonexistentModelVersion{125};
49 class GetModelStatusImplTest :
public ::testing::Test {
51 static void SetUpTestSuite() {
52 TF_ASSERT_OK(CreateServerCore(&server_core_));
55 static void TearDownTestSuite() { server_core_.reset(); }
58 static Status CreateServerCore(std::unique_ptr<ServerCore>* server_core) {
59 ModelServerConfig config;
60 auto* model_config = config.mutable_model_config_list()->add_config();
61 model_config->set_name(kTestModelName);
62 model_config->set_base_path(test_util::TestSrcDirPath(kTestModelBasePath));
63 auto* specific_versions =
64 model_config->mutable_model_version_policy()->mutable_specific();
65 specific_versions->add_versions(kTestModelVersion1);
66 specific_versions->add_versions(kTestModelVersion2);
68 model_config->set_model_platform(kTensorFlowModelPlatform);
72 ServerCore::Options options;
73 options.model_server_config = config;
75 options.platform_config_map =
76 CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
79 options.num_initial_load_threads = options.num_load_threads;
80 options.aspired_version_policy =
81 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
85 ServerCore* GetServerCore() {
return server_core_.get(); }
88 static std::unique_ptr<ServerCore> server_core_;
91 std::unique_ptr<ServerCore> GetModelStatusImplTest::server_core_;
93 TEST_F(GetModelStatusImplTest, MissingOrEmptyModelSpecFailure) {
94 GetModelStatusRequest request;
95 GetModelStatusResponse response;
99 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
100 GetModelStatusImpl::GetModelStatus(GetServerCore(), request, &response)
104 TEST_F(GetModelStatusImplTest, InvalidModelNameFailure) {
105 GetModelStatusRequest request;
106 GetModelStatusResponse response;
108 ModelSpec* model_spec = request.mutable_model_spec();
109 model_spec->set_name(kNonexistentModelName);
114 tensorflow::error::NOT_FOUND,
115 GetModelStatusImpl::GetModelStatus(GetServerCore(), request, &response)
117 EXPECT_EQ(0, response.model_version_status_size());
120 TEST_F(GetModelStatusImplTest, InvalidModelVersionFailure) {
121 GetModelStatusRequest request;
122 GetModelStatusResponse response;
124 ModelSpec* model_spec = request.mutable_model_spec();
125 model_spec->set_name(kTestModelName);
126 model_spec->mutable_version()->set_value(kNonexistentModelVersion);
130 tensorflow::error::NOT_FOUND,
131 GetModelStatusImpl::GetModelStatus(GetServerCore(), request, &response)
133 EXPECT_EQ(0, response.model_version_status_size());
136 TEST_F(GetModelStatusImplTest, AllVersionsSuccess) {
137 GetModelStatusRequest request;
138 GetModelStatusResponse response;
140 ModelSpec* model_spec = request.mutable_model_spec();
141 model_spec->set_name(kTestModelName);
146 GetModelStatusImpl::GetModelStatus(GetServerCore(), request, &response));
147 EXPECT_EQ(2, response.model_version_status_size());
148 std::set<int64_t> expected_versions = {kTestModelVersion1,
150 std::set<int64_t> actual_versions = {
151 response.model_version_status(0).version(),
152 response.model_version_status(1).version()};
153 EXPECT_EQ(expected_versions, actual_versions);
154 EXPECT_EQ(tensorflow::error::OK,
155 response.model_version_status(0).status().error_code());
156 EXPECT_EQ(
"", response.model_version_status(0).status().error_message());
157 EXPECT_EQ(tensorflow::error::OK,
158 response.model_version_status(1).status().error_code());
159 EXPECT_EQ(
"", response.model_version_status(1).status().error_message());
162 TEST_F(GetModelStatusImplTest, SingleVersionSuccess) {
163 GetModelStatusRequest request;
164 GetModelStatusResponse response;
166 ModelSpec* model_spec = request.mutable_model_spec();
167 model_spec->set_name(kTestModelName);
168 model_spec->mutable_version()->set_value(kTestModelVersion1);
172 GetModelStatusImpl::GetModelStatus(GetServerCore(), request, &response));
173 EXPECT_EQ(1, response.model_version_status_size());
174 EXPECT_EQ(kTestModelVersion1, response.model_version_status(0).version());
175 EXPECT_EQ(tensorflow::error::OK,
176 response.model_version_status(0).status().error_code());
177 EXPECT_EQ(
"", response.model_version_status(0).status().error_message());
182 TEST_F(GetModelStatusImplTest, ModelSpecOverride) {
183 GetModelStatusRequest request;
184 request.mutable_model_spec()->set_name(kTestModelName);
185 auto model_spec_override =
186 test_util::CreateProto<ModelSpec>(
"name: \"nonexistent_model\"");
188 GetModelStatusResponse response;
190 tensorflow::error::NOT_FOUND,
191 GetModelStatusImpl::GetModelStatus(GetServerCore(), request, &response)
193 EXPECT_EQ(tensorflow::error::NOT_FOUND,
194 GetModelStatusImpl::GetModelStatusWithModelSpec(
195 GetServerCore(), model_spec_override, request, &response)
static Status Create(Options options, std::unique_ptr< ServerCore > *core)