16 #include "tensorflow_serving/model_servers/http_rest_api_handler.h"
24 #include "rapidjson/document.h"
25 #include "rapidjson/error/en.h"
26 #include <gmock/gmock.h>
27 #include <gtest/gtest.h>
28 #include "absl/strings/escaping.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/time/clock.h"
32 #include "tensorflow/cc/saved_model/loader.h"
33 #include "tensorflow/cc/saved_model/signature_constants.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/env.h"
37 #include "tsl/platform/errors.h"
38 #include "tensorflow_serving/core/availability_preserving_policy.h"
39 #include "tensorflow_serving/model_servers/model_platform_types.h"
40 #include "tensorflow_serving/model_servers/platform_config_util.h"
41 #include "tensorflow_serving/model_servers/server_core.h"
42 #include "tensorflow_serving/model_servers/server_init.h"
43 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
44 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
45 #include "tensorflow_serving/test_util/test_util.h"
47 namespace tensorflow {
51 using ::testing::HasSubstr;
52 using ::testing::UnorderedElementsAreArray;
// Fixture constants describing the half_plus_two SavedModel that every test in
// this file talks to.

// Model location relative to the TensorFlow source tree; CreateServerCore
// resolves it through test_util::TensorflowTestSrcDirPath.
constexpr char kTestModelBasePath[]{"cc/saved_model/testdata/half_plus_two"};

// Name the model is registered under; also the model segment that follows
// "/v1/models/" in the request paths the tests build.
constexpr char kTestModelName[]{"saved_model_half_plus_two_2_versions"};

// Version added to the model's "specific" version policy in CreateServerCore
// and addressed explicitly through the "/versions/" request paths.
constexpr int kTestModelVersion1{123};

// Version label registered in CreateServerCore and addressed through the
// "/labels/" request paths.
constexpr char kTestModelVersionLabel[]{"Version_Label"};

// Label that the visible setup never registers; requests using it are expected
// to fail with "Unrecognized servable version label".
constexpr char kNonexistentModelVersionLabel[]{"Version_Nonexistent"};

// Response headers as (name, value) pairs, the shape the tests compare against
// the headers ProcessRequest writes through its `headers` out-parameter.
using HeaderList = std::vector<std::pair<string, string>>;
// Test fixture: one ServerCore serving the half_plus_two model is shared by all
// tests in the suite (held in the static `server_core_`), and each test gets its
// own HttpRestApiHandler built on top of it by the fixture constructor.
62 class HttpRestApiHandlerTest :
public ::testing::Test {
// Runs once before the first test of the suite. Creates the shared ServerCore,
// then polls every 500 ms, logging progress, until at least `total` servables
// are available. NOTE(review): the declarations of `count` and `total` are on
// lines not visible here. The wait has no timeout of its own, so a model that
// never becomes available stalls the suite until the test runner's deadline.
64 static void SetUpTestSuite() {
65 TF_ASSERT_OK(CreateServerCore(&server_core_));
69 while ((count = server_core_->ListAvailableServableIds().size()) < total) {
70 LOG(INFO) <<
"Available servables: " << count <<
" waiting for " << total;
71 absl::SleepFor(absl::Milliseconds(500));
// Log every servable that ended up available; useful when a test fails.
73 for (
const auto& s : server_core_->ListAvailableServableIds()) {
74 LOG(INFO) <<
"Available servable: " << s.DebugString();
78 static void TearDownTestSuite() { server_core_.reset(); }
81 HttpRestApiHandlerTest() : handler_(-1, GetServerCore()) {}
// Builds a ModelServerConfig that serves kTestModelName from the TensorFlow
// source tree, pinned through a "specific" version policy to kTestModelVersion1,
// and sets up ServerCore options for the TensorFlow platform.
// NOTE(review): the lines that create *server_core from `options` are not
// visible here (presumably a ServerCore::Create call); *server_core must be
// populated before the final ReloadConfig, which dereferences it.
// Returns the first error from platform setup, otherwise the ReloadConfig status.
83 static Status CreateServerCore(std::unique_ptr<ServerCore>* server_core) {
84 ModelServerConfig config;
85 auto* model_config = config.mutable_model_config_list()->add_config();
86 model_config->set_name(kTestModelName);
87 model_config->set_base_path(
88 test_util::TensorflowTestSrcDirPath(kTestModelBasePath));
89 auto* specific_versions =
90 model_config->mutable_model_version_policy()->mutable_specific();
91 specific_versions->add_versions(kTestModelVersion1);
93 model_config->set_model_platform(kTensorFlowModelPlatform);
// `options` takes a copy of `config` here; later edits made through
// `model_config` change only `config`, not this copy.
97 ServerCore::Options options;
98 options.model_server_config = config;
100 auto* tf_serving_registry =
101 init::TensorflowServingFunctionRegistration::GetRegistry();
102 TF_RETURN_IF_ERROR(tf_serving_registry->GetSetupPlatformConfigMap()(
103 SessionBundleConfig(), options.platform_config_map));
// Use the same number of threads for the initial load as for later loads.
106 options.num_initial_load_threads = options.num_load_threads;
107 options.aspired_version_policy =
108 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
// The version label is added to `config` after `options` took its copy, so it
// only takes effect through the ReloadConfig call below.
// NOTE(review): the right-hand side of this assignment is on a line not visible
// here (presumably kTestModelVersion1).
112 (*model_config->mutable_version_labels())[kTestModelVersionLabel] =
114 return server_core->get()->ReloadConfig(config);
// Extracts the "error" string from a JSON response body such as the one
// ProcessRequest writes to `output`, and C-unescapes it.
// Every failure path visible here returns a human-readable diagnostic string
// rather than a Status, so callers can match the result with HasSubstr.
// NOTE(review): the success-path return and the declaration of `errmsg` are on
// lines not visible here; presumably the unescaped message is returned.
117 string GetJsonErrorMsg(
const string& json) {
118 rapidjson::Document doc;
119 if (doc.Parse(json.c_str()).HasParseError()) {
120 return absl::StrCat(
"JSON Parse error: ",
121 rapidjson::GetParseError_En(doc.GetParseError()),
124 if (!doc.IsObject()) {
125 return absl::StrCat(
"JSON does not have top-level object: ", json);
127 const auto itr = doc.FindMember(
"error");
128 if (itr == doc.MemberEnd() || !itr->value.IsString()) {
129 return absl::StrCat(
"JSON object does not have \'error\' key: ", json);
// Copy with the explicit length so that embedded NUL characters survive.
131 string escaped_errmsg;
132 escaped_errmsg.assign(itr->value.GetString(), itr->value.GetStringLength());
// The message is stored C-escaped; unescape it before returning.
134 string unescaping_error;
135 if (!absl::CUnescape(escaped_errmsg, &errmsg, &unescaping_error)) {
136 return absl::StrCat(
"Error unescaping JSON error message: ",
142 ServerCore* GetServerCore() {
return server_core_.get(); }
// Handler under test; rebuilt for every test by the fixture constructor.
144 HttpRestApiHandler handler_;
// ServerCore shared by every test in the suite: created in SetUpTestSuite and
// released in TearDownTestSuite.
147 static std::unique_ptr<ServerCore> server_core_;
150 std::unique_ptr<ServerCore> HttpRestApiHandlerTest::server_core_;
// Parses both strings as JSON and compares the parsed documents rather than
// the raw text, so formatting differences do not matter.
// Returns InvalidArgument (with the parse error, offset and offending text) if
// either input fails to parse, and InvalidArgument quoting the inputs if they
// differ. Otherwise returns OK.
// NOTE(review): the comparison condition is on a line not visible here;
// presumably rapidjson Document equality (doc1 != doc2) — confirm.
152 Status CompareJson(
const string& json1,
const string& json2) {
153 rapidjson::Document doc1;
154 if (doc1.Parse(json1.c_str()).HasParseError()) {
155 return errors::InvalidArgument(
156 "JSON Parse error: ", rapidjson::GetParseError_En(doc1.GetParseError()),
157 " at offset: ", doc1.GetErrorOffset(),
" JSON: ", json1);
159 rapidjson::Document doc2;
160 if (doc2.Parse(json2.c_str()).HasParseError()) {
161 return errors::InvalidArgument(
162 "JSON Parse error: ", rapidjson::GetParseError_En(doc2.GetParseError()),
163 " at offset: ", doc2.GetErrorOffset(),
" JSON: ", json2);
166 return errors::InvalidArgument(
"JSON Different. JSON1: ", json1,
169 return absl::OkStatus();
// The handler's path regex must fully match the "/v1/models" API prefix and
// reject unrelated server paths.
172 TEST_F(HttpRestApiHandlerTest, kPathRegex) {
173 EXPECT_TRUE(RE2::FullMatch(
"/v1/models", handler_.kPathRegex));
174 EXPECT_FALSE(RE2::FullMatch(
"/statuspage", handler_.kPathRegex));
175 EXPECT_FALSE(RE2::FullMatch(
"/index", handler_.kPathRegex));
// Requests with a malformed path or an unsupported method/path combination are
// rejected with InvalidArgument. The message names the specific problem
// ("Malformed request", "Missing model name", "Failed to convert version").
// NOTE(review): the declarations of `headers` and `status` are on lines not
// visible here.
178 TEST_F(HttpRestApiHandlerTest, UnsupportedApiCalls) {
180 string model_name, method, output;
182 status = handler_.ProcessRequest(
"GET",
"/v1/foo",
"", &headers, &model_name,
184 EXPECT_TRUE(errors::IsInvalidArgument(status));
185 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
187 status = handler_.ProcessRequest(
"POST",
"/v1/foo",
"", &headers, &model_name,
189 EXPECT_TRUE(errors::IsInvalidArgument(status));
190 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// A GET on the bare models collection is reported as a missing model name.
192 status = handler_.ProcessRequest(
"GET",
"/v1/models",
"", &headers,
193 &model_name, &method, &output);
194 EXPECT_TRUE(errors::IsInvalidArgument(status));
195 EXPECT_THAT(status.message(), HasSubstr(
"Missing model name"));
196 status = handler_.ProcessRequest(
"GET",
"/v1/models/debug/model_name",
"",
197 &headers, &model_name, &method, &output);
198 EXPECT_TRUE(errors::IsInvalidArgument(status));
199 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
201 status = handler_.ProcessRequest(
"POST",
"/v1/models",
"", &headers,
202 &model_name, &method, &output);
203 EXPECT_TRUE(errors::IsInvalidArgument(status));
204 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
206 status = handler_.ProcessRequest(
"GET",
"/v1/models/foo:predict",
"",
207 &headers, &model_name, &method, &output);
208 EXPECT_TRUE(errors::IsInvalidArgument(status));
209 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
211 status = handler_.ProcessRequest(
"GET",
"/v1/models/foo/version/50:predict",
212 "", &headers, &model_name, &method, &output);
213 EXPECT_TRUE(errors::IsInvalidArgument(status));
214 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
216 status = handler_.ProcessRequest(
"POST",
"/v1/models/foo/version/50:regress",
217 "", &headers, &model_name, &method, &output);
218 EXPECT_TRUE(errors::IsInvalidArgument(status));
219 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Non-numeric version.
// NOTE(review): this call's result is not visibly assigned to `status`
// (presumably `status =` sits on the omitted preceding line). If it is truly
// missing, the checks below re-test the previous call's status — confirm.
222 handler_.ProcessRequest(
"POST",
"/v1/models/foo/versions/HELLO:regress",
223 "", &headers, &model_name, &method, &output);
224 EXPECT_TRUE(errors::IsInvalidArgument(status));
225 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// A version of UINT64_MAX cannot be converted and is rejected.
227 status = handler_.ProcessRequest(
229 absl::StrCat(
"/v1/models/foo/versions/",
230 std::numeric_limits<uint64_t>::max(),
":regress"),
231 "", &headers, &model_name, &method, &output);
232 EXPECT_TRUE(errors::IsInvalidArgument(status));
233 EXPECT_THAT(status.message(), HasSubstr(
"Failed to convert version"));
// Metadata is served for GET (see GetModelMetadata); POST is rejected.
235 status = handler_.ProcessRequest(
"POST",
"/v1/models/foo/metadata",
"",
236 &headers, &model_name, &method, &output);
237 EXPECT_TRUE(errors::IsInvalidArgument(status));
238 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Singular "label" path segment.
// NOTE(review): as above, the assignment to `status` is not on a visible line;
// confirm it exists, or the checks below test a stale status.
241 handler_.ProcessRequest(
"POST",
"/v1/models/foo/label/some_label:regress",
242 "", &headers, &model_name, &method, &output);
243 EXPECT_TRUE(errors::IsInvalidArgument(status));
244 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// A path may not combine a version and a label.
246 status = handler_.ProcessRequest(
247 "POST",
"/v1/models/foo/versions/50/labels/some_label:regress",
"",
248 &headers, &model_name, &method, &output);
249 EXPECT_TRUE(errors::IsInvalidArgument(status));
250 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Labels must go through "/labels/", not "/versions/".
252 status = handler_.ProcessRequest(
"POST",
253 "/v1/models/foo/versions/some_label:regress",
254 "", &headers, &model_name, &method, &output);
255 EXPECT_TRUE(errors::IsInvalidArgument(status));
256 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Well-formed predict requests that address a model or version the server does
// not have must fail with NotFound. Covered cases: an unknown model, a version
// of an unknown model, and a name that only extends the loaded model's name
// with a "99" suffix.
259 TEST_F(HttpRestApiHandlerTest, PredictModelNameVersionErrors) {
261 string model_name, method, output;
264 status = handler_.ProcessRequest(
"POST",
"/v1/models/foo:predict",
265 R
"({ "instances": [1] })", &headers,
266 &model_name, &method, &output);
267 EXPECT_TRUE(errors::IsNotFound(status));
270 status = handler_.ProcessRequest(
"POST",
"/v1/models/foo/versions/50:predict",
271 R
"({ "instances": [1] })", &headers,
272 &model_name, &method, &output);
273 EXPECT_TRUE(errors::IsNotFound(status));
276 status = handler_.ProcessRequest(
277 "POST", absl::StrCat(
"/v1/models/", kTestModelName,
"99:predict"),
278 R
"({ "instances": [1] })", &headers, &model_name, &method, &output);
279 EXPECT_TRUE(errors::IsNotFound(status));
// Malformed predict request bodies against the loaded model are rejected with
// InvalidArgument. Covered cases: an empty body, a non-JSON body, instances of
// the wrong type, an unregistered version label, and a non-string
// signature_name.
282 TEST_F(HttpRestApiHandlerTest, PredictRequestErrors) {
284 string model_name, method, output;
286 const string& req_path =
287 absl::StrCat(
"/v1/models/", kTestModelName,
":predict");
290 status = handler_.ProcessRequest(
"POST", req_path,
"", &headers, &model_name,
292 EXPECT_TRUE(errors::IsInvalidArgument(status));
293 EXPECT_THAT(status.message(),
294 HasSubstr(
"JSON Parse error: The document is empty"));
297 status = handler_.ProcessRequest(
"POST", req_path,
"instances = [1, 2]",
298 &headers, &model_name, &method, &output);
299 EXPECT_TRUE(errors::IsInvalidArgument(status));
300 EXPECT_THAT(status.message(), HasSubstr(
"JSON Parse error: Invalid value"));
// String instances do not match the model's float input.
303 status = handler_.ProcessRequest(
"POST", req_path,
304 R
"({ "instances": ["x", "y"] })", &headers,
305 &model_name, &method, &output);
306 EXPECT_TRUE(errors::IsInvalidArgument(status));
307 EXPECT_THAT(status.message(), HasSubstr("not of expected type: float"));
310 status = handler_.ProcessRequest(
312 absl::StrCat(
"/v1/models/", kTestModelName,
"/labels/",
313 kNonexistentModelVersionLabel,
":predict"),
314 R
"({ "instances": ["x", "y"] })", &headers, &model_name, &method,
316 EXPECT_TRUE(errors::IsInvalidArgument(status));
317 EXPECT_THAT(status.message(),
318 HasSubstr("Unrecognized servable version label"));
// A numeric signature_name is rejected. This message is read from the JSON body
// written to `output` (via GetJsonErrorMsg), not from status.message().
// NOTE(review): the assignment of this call's result to `status` is not on a
// visible line (presumably on the omitted preceding line) — confirm, otherwise
// the IsInvalidArgument check below re-tests the previous status.
322 handler_.ProcessRequest(
"POST", req_path, R
"({ "signature_name": 100 })",
323 &headers, &model_name, &method, &output);
324 EXPECT_TRUE(errors::IsInvalidArgument(status));
325 EXPECT_THAT(GetJsonErrorMsg(output),
326 HasSubstr("'signature_name' key must be a string value."));
// Successful predict calls through each addressing form: the default version,
// an explicit version, and a version label.
// Both request formats are covered: row format ("instances" -> "predictions")
// and columnar format ("inputs" -> "outputs").
// The expected values follow y = x / 2 + 2. Every response must carry exactly
// one header, Content-Type: application/json.
329 TEST_F(HttpRestApiHandlerTest, Predict) {
331 string model_name, method, output;
334 TF_EXPECT_OK(handler_.ProcessRequest(
335 "POST", absl::StrCat(
"/v1/models/", kTestModelName,
":predict"),
336 R
"({"instances": [[1.0, 2.0], [3.0, 4.0]]})", &headers, &model_name,
339 CompareJson(output, R"({ "predictions": [[2.5, 3.0], [3.5, 4.0]] })"));
340 EXPECT_THAT(headers, UnorderedElementsAreArray(
341 (HeaderList){{"Content-Type",
"application/json"}}));
// Explicit version.
344 TF_EXPECT_OK(handler_.ProcessRequest(
346 absl::StrCat(
"/v1/models/", kTestModelName,
"/versions/",
347 kTestModelVersion1,
":predict"),
348 R
"({"instances": [1.0, 2.0]})", &headers, &model_name, &method, &output));
349 TF_EXPECT_OK(CompareJson(output, R"({ "predictions": [2.5, 3.0] })"));
350 EXPECT_THAT(headers, UnorderedElementsAreArray(
351 (HeaderList){{"Content-Type",
"application/json"}}));
// Version label.
354 TF_EXPECT_OK(handler_.ProcessRequest(
356 absl::StrCat(
"/v1/models/", kTestModelName,
"/labels/",
357 kTestModelVersionLabel,
":predict"),
358 R
"({"instances": [1.0, 2.0]})", &headers, &model_name, &method, &output));
359 TF_EXPECT_OK(CompareJson(output, R"({ "predictions": [2.5, 3.0] })"));
360 EXPECT_THAT(headers, UnorderedElementsAreArray(
361 (HeaderList){{"Content-Type",
"application/json"}}));
// Explicit signature_name with row-format "instances".
364 TF_EXPECT_OK(handler_.ProcessRequest(
366 absl::StrCat(
"/v1/models/", kTestModelName,
"/versions/",
367 kTestModelVersion1,
":predict"),
368 R
"({"signature_name": "serving_default", "instances": [3.0, 4.0]})",
369 &headers, &model_name, &method, &output));
370 TF_EXPECT_OK(CompareJson(output, R"({ "predictions": [3.5, 4.0] })"));
371 EXPECT_THAT(headers, UnorderedElementsAreArray(
372 (HeaderList){{"Content-Type",
"application/json"}}));
// Columnar "inputs" produce an "outputs" response key.
375 TF_EXPECT_OK(handler_.ProcessRequest(
377 absl::StrCat(
"/v1/models/", kTestModelName,
"/versions/",
378 kTestModelVersion1,
":predict"),
379 R
"({"signature_name": "serving_default", "inputs": [3.0, 4.0]})",
380 &headers, &model_name, &method, &output));
381 TF_EXPECT_OK(CompareJson(output, R"({ "outputs": [3.5, 4.0] })"));
382 EXPECT_THAT(headers, UnorderedElementsAreArray(
383 (HeaderList){{"Content-Type",
"application/json"}}));
// The regress_x_to_y signature maps x = 80 to 42 (consistent with
// y = x / 2 + 2). This holds for the default version, an explicit version and a
// version label, and each response carries a JSON Content-Type header.
386 TEST_F(HttpRestApiHandlerTest, Regress) {
388 string model_name, method, output;
391 TF_EXPECT_OK(handler_.ProcessRequest(
392 "POST", absl::StrCat(
"/v1/models/", kTestModelName,
":regress"),
393 R
"({"signature_name": "regress_x_to_y", "examples": [ { "x": 80.0 } ] })",
394 &headers, &model_name, &method, &output));
395 TF_EXPECT_OK(CompareJson(output, R"({ "results": [42] })"));
396 EXPECT_THAT(headers, UnorderedElementsAreArray(
397 (HeaderList){{"Content-Type",
"application/json"}}));
// Explicit version.
400 TF_EXPECT_OK(handler_.ProcessRequest(
402 absl::StrCat(
"/v1/models/", kTestModelName,
"/versions/",
403 kTestModelVersion1,
":regress"),
404 R
"({"signature_name": "regress_x_to_y", "examples": [ { "x": 80.0 } ] })",
405 &headers, &model_name, &method, &output));
406 TF_EXPECT_OK(CompareJson(output, R"({ "results": [42] })"));
407 EXPECT_THAT(headers, UnorderedElementsAreArray(
408 (HeaderList){{"Content-Type",
"application/json"}}));
// Version label.
411 TF_EXPECT_OK(handler_.ProcessRequest(
413 absl::StrCat(
"/v1/models/", kTestModelName,
"/labels/",
414 kTestModelVersionLabel,
":regress"),
415 R
"({"signature_name": "regress_x_to_y", "examples": [ { "x": 80.0 } ] })",
416 &headers, &model_name, &method, &output));
417 TF_EXPECT_OK(CompareJson(output, R"({ "results": [42] })"));
418 EXPECT_THAT(headers, UnorderedElementsAreArray(
419 (HeaderList){{"Content-Type",
"application/json"}}));
// The classify_x_to_y signature returns one (empty label, score) pair per
// example, with the score following y = x / 2 + 2 (20 -> 12, 10 -> 7).
// The default version and a version label are both covered.
422 TEST_F(HttpRestApiHandlerTest, Classify) {
424 string model_name, method, output;
427 TF_EXPECT_OK(handler_.ProcessRequest(
428 "POST", absl::StrCat(
"/v1/models/", kTestModelName,
":classify"),
429 R
"({"signature_name": "classify_x_to_y", "examples": [ { "x": 20.0 } ] })",
430 &headers, &model_name, &method, &output));
431 TF_EXPECT_OK(CompareJson(output, R"({ "results": [[["", 12]]] })"));
432 EXPECT_THAT(headers, UnorderedElementsAreArray(
433 (HeaderList){{"Content-Type",
"application/json"}}));
// Version label.
436 TF_EXPECT_OK(handler_.ProcessRequest(
438 absl::StrCat(
"/v1/models/", kTestModelName,
"/labels/",
439 kTestModelVersionLabel,
":classify"),
440 R
"({"signature_name": "classify_x_to_y", "examples": [ { "x": 10.0 } ] })",
441 &headers, &model_name, &method, &output));
442 TF_EXPECT_OK(CompareJson(output, R"({ "results": [[["", 7]]] })"));
443 EXPECT_THAT(headers, UnorderedElementsAreArray(
444 (HeaderList){{"Content-Type",
"application/json"}}));
// A GET on the model path returns its version status as JSON. Three addressing
// forms are covered: the default, an explicit version, and a version label.
// Each expected document lists a "model_version_status" entry whose state is
// "AVAILABLE".
// NOTE(review): the expected JSON literals are only partially visible here.
447 TEST_F(HttpRestApiHandlerTest, GetStatus) {
449 string model_name, method, output;
453 TF_EXPECT_OK(handler_.ProcessRequest(
454 "GET", absl::StrCat(
"/v1/models/", kTestModelName),
"", &headers,
455 &model_name, &method, &output));
456 EXPECT_THAT(headers, UnorderedElementsAreArray(
457 (HeaderList){{
"Content-Type",
"application/json"}}));
458 TF_EXPECT_OK(CompareJson(output, R
"({
459 "model_version_status": [
462 "state": "AVAILABLE",
473 handler_.ProcessRequest(
"GET",
474 absl::StrCat(
"/v1/models/", kTestModelName,
475 "/versions/", kTestModelVersion1),
476 "", &headers, &model_name, &method, &output));
477 EXPECT_THAT(headers, UnorderedElementsAreArray(
478 (HeaderList){{
"Content-Type",
"application/json"}}));
479 TF_EXPECT_OK(CompareJson(output, R
"({
480 "model_version_status": [
483 "state": "AVAILABLE",
494 handler_.ProcessRequest(
"GET",
495 absl::StrCat(
"/v1/models/", kTestModelName,
496 "/labels/", kTestModelVersionLabel),
497 "", &headers, &model_name, &method, &output));
498 EXPECT_THAT(headers, UnorderedElementsAreArray(
499 (HeaderList){{
"Content-Type",
"application/json"}}));
500 TF_EXPECT_OK(CompareJson(output, R
"({
501 "model_version_status": [
504 "state": "AVAILABLE",
// A GET on "<model>/metadata" returns JSON with a JSON Content-Type header,
// and the body must match the checked-in metadata fixture.
514 TEST_F(HttpRestApiHandlerTest, GetModelMetadata) {
516 string model_name, method, output;
517 string test_file_contents;
520 TF_EXPECT_OK(handler_.ProcessRequest(
521 "GET", absl::StrCat(
"/v1/models/", kTestModelName,
"/metadata"),
"",
522 &headers, &model_name, &method, &output));
523 EXPECT_THAT(headers, UnorderedElementsAreArray(
524 (HeaderList){{
"Content-Type",
"application/json"}}));
// The fixture path is relative ("./..."), so it resolves against the process's
// current working directory (presumably the test runfiles root — confirm).
// The model base path, by contrast, goes through TensorflowTestSrcDirPath.
// If the read fails, test_file_contents stays empty and CompareJson then fails
// as well.
525 const string fname = absl::StrCat(
526 "./tensorflow_serving/servables/tensorflow/testdata",
527 "/saved_model_half_plus_two_2_versions_metadata.json");
528 TF_EXPECT_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), fname,
529 &test_file_contents));
530 TF_EXPECT_OK(CompareJson(output, test_file_contents));
// NOTE(review): stray, unterminated declaration that is not part of this test
// file's visible code. Its shape (static, taking Options by value and writing a
// std::unique_ptr<ServerCore>*) looks like ServerCore's factory signature,
// presumably pasted in from server_core.h. Confirm and remove.
static Status Create(Options options, std::unique_ptr< ServerCore > *core)