16 #include "tensorflow_serving/model_servers/tfrt_http_rest_api_handler.h"
24 #include "rapidjson/document.h"
25 #include "rapidjson/error/en.h"
26 #include <gmock/gmock.h>
27 #include <gtest/gtest.h>
28 #include "absl/strings/escaping.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/time/clock.h"
31 #include "absl/time/time.h"
33 #include "tensorflow/cc/saved_model/loader.h"
34 #include "tensorflow/cc/saved_model/signature_constants.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/env.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
40 #include "tensorflow_serving/core/availability_preserving_policy.h"
41 #include "tensorflow_serving/model_servers/model_platform_types.h"
42 #include "tensorflow_serving/model_servers/platform_config_util.h"
43 #include "tensorflow_serving/model_servers/server_core.h"
44 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
45 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
46 #include "tensorflow_serving/test_util/test_util.h"
48 namespace tensorflow {
// gMock matchers used by the expectations in this file.
52 using ::testing::HasSubstr;
53 using ::testing::UnorderedElementsAreArray;
// Name under which the test model is registered with the ServerCore (see
// CreateServerCore) and which appears in the request paths of the tests.
55 constexpr
char kTestModelName[] =
"saved_model_half_plus_two_2_versions";
// The model version added to the "specific" version policy.
56 constexpr
int kTestModelVersion1 = 123;
// Key inserted into the model config's version_labels map.
57 constexpr
char kTestModelVersionLabel[] =
"Version_Label";
// Label used by PredictRequestErrors to trigger the
// "Unrecognized servable version label" error.
58 constexpr
char kNonexistentModelVersionLabel[] =
"Version_Nonexistent";
// Model directory; resolved through test_util::TensorflowTestSrcDirPath.
59 constexpr
char kTestModelBasePath[] =
"cc/saved_model/testdata/half_plus_two";
// List of (name, value) string pairs; the tests compare `headers` against it.
61 using HeaderList = std::vector<std::pair<string, string>>;
// Test fixture for TFRTHttpRestApiHandler. The ServerCore is a static member
// shared by the whole suite; handler_ is a per-instance member built on top
// of it.
63 class TFRTHttpRestApiHandlerTest :
public ::testing::Test {
// Suite-wide setup: installs a global TFRT runtime, creates the shared
// ServerCore, then polls every 500 ms until `count` reaches `total` and logs
// each available servable.
// NOTE(review): `count` and `total` are declared on lines not visible in
// this listing.
65 static void SetUpTestSuite() {
66 tfrt_stub::SetGlobalRuntime(
67 tfrt_stub::Runtime::Create(4));
69 TF_ASSERT_OK(CreateServerCore(&server_core_));
73 while ((count = server_core_->ListAvailableServableIds().size()) < total) {
74 LOG(INFO) <<
"Available servables: " << count <<
" waiting for " << total;
75 absl::SleepFor(absl::Milliseconds(500));
77 for (
const auto& s : server_core_->ListAvailableServableIds()) {
78 LOG(INFO) <<
"Available servable: " << s.DebugString();
// Suite-wide teardown: releases the shared ServerCore.
82 static void TearDownTestSuite() { server_core_.reset(); }
// Builds the handler from the shared ServerCore.
// NOTE(review): 10000 is presumably a timeout value - confirm against the
// TFRTHttpRestApiHandler constructor.
85 TFRTHttpRestApiHandlerTest()
86 : handler_(10000, GetServerCore()) {}
// Prepares the configuration for the test ServerCore: one model named
// kTestModelName at kTestModelBasePath, restricted to kTestModelVersion1 on
// the TensorFlow platform, with a TfrtSavedModelSourceAdapterConfig (default
// SessionBundleConfig as legacy config) and AvailabilityPreservingPolicy.
// Returns the status of ReloadConfig().
// NOTE(review): the statement that creates *server_core and the right-hand
// side of the version-label assignment are on lines not visible here.
88 static Status CreateServerCore(std::unique_ptr<ServerCore>* server_core) {
89 ModelServerConfig config;
90 auto model_config = config.mutable_model_config_list()->add_config();
91 model_config->set_name(kTestModelName);
92 model_config->set_base_path(
93 test_util::TensorflowTestSrcDirPath(kTestModelBasePath));
94 auto* specific_versions =
95 model_config->mutable_model_version_policy()->mutable_specific();
96 specific_versions->add_versions(kTestModelVersion1);
97 model_config->set_model_platform(kTensorFlowModelPlatform);
101 ServerCore::Options options;
102 options.model_server_config = config;
103 PlatformConfigMap platform_config_map;
104 ::google::protobuf::Any source_adapter_config;
105 TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
106 *saved_model_bundle_source_adapter_config.mutable_saved_model_config()
107 ->mutable_legacy_config() = SessionBundleConfig();
108 source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
109 (*(*platform_config_map
110 .mutable_platform_configs())[kTensorFlowModelPlatform]
111 .mutable_source_adapter_config()) = source_adapter_config;
112 options.platform_config_map = platform_config_map;
113 options.aspired_version_policy =
114 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
117 options.num_initial_load_threads = options.num_load_threads;
121 (*model_config->mutable_version_labels())[kTestModelVersionLabel] =
123 return server_core->get()->ReloadConfig(config);
// Pulls the string stored under the top-level "error" key of `json` and
// C-unescapes it. If the input does not parse, is not an object, lacks a
// string "error" member, or cannot be unescaped, a descriptive message is
// returned instead.
// NOTE(review): `errmsg` and the final return are on lines not visible here.
126 string GetJsonErrorMsg(
const string& json) {
127 rapidjson::Document doc;
128 if (doc.Parse(json.c_str()).HasParseError()) {
129 return absl::StrCat(
"JSON Parse error: ",
130 rapidjson::GetParseError_En(doc.GetParseError()),
133 if (!doc.IsObject()) {
134 return absl::StrCat(
"JSON does not have top-level object: ", json);
136 const auto itr = doc.FindMember(
"error");
137 if (itr == doc.MemberEnd() || !itr->value.IsString()) {
138 return absl::StrCat(
"JSON object does not have \'error\' key: ", json);
140 string escaped_errmsg;
141 escaped_errmsg.assign(itr->value.GetString(), itr->value.GetStringLength());
143 string unescaping_error;
144 if (!absl::CUnescape(escaped_errmsg, &errmsg, &unescaping_error)) {
145 return absl::StrCat(
"Error unescaping JSON error message: ",
// Non-owning pointer to the suite-wide ServerCore.
151 ServerCore* GetServerCore() {
return server_core_.get(); }
// Handler under test.
153 TFRTHttpRestApiHandler handler_;
// ServerCore shared by all tests; created in SetUpTestSuite and reset in
// TearDownTestSuite.
156 static std::unique_ptr<ServerCore> server_core_;
// Out-of-class definition of the fixture's static ServerCore holder.
159 std::unique_ptr<ServerCore> TFRTHttpRestApiHandlerTest::server_core_;
// Parses `json1` and `json2` with rapidjson. If either fails to parse, an
// InvalidArgument status carrying the parse error, its offset and the
// offending text is returned. The visible tail returns InvalidArgument with a
// "JSON Different" message.
// NOTE(review): the comparison of the two documents and the success return
// are on lines not visible in this listing.
161 Status CompareJson(
const string& json1,
const string& json2) {
162 rapidjson::Document doc1;
163 if (doc1.Parse(json1.c_str()).HasParseError()) {
164 return errors::InvalidArgument(
165 "JSON Parse error: ", rapidjson::GetParseError_En(doc1.GetParseError()),
166 " at offset: ", doc1.GetErrorOffset(),
" JSON: ", json1);
168 rapidjson::Document doc2;
169 if (doc2.Parse(json2.c_str()).HasParseError()) {
170 return errors::InvalidArgument(
171 "JSON Parse error: ", rapidjson::GetParseError_En(doc2.GetParseError()),
172 " at offset: ", doc2.GetErrorOffset(),
" JSON: ", json2);
175 return errors::InvalidArgument(
"JSON Different. JSON1: ", json1,
// Checks that handler_.kPathRegex fully matches "/v1/models" and does not
// match "/statuspage" or "/index".
181 TEST_F(TFRTHttpRestApiHandlerTest, kPathRegex) {
182 EXPECT_TRUE(RE2::FullMatch(
"/v1/models", handler_.kPathRegex));
183 EXPECT_FALSE(RE2::FullMatch(
"/statuspage", handler_.kPathRegex));
184 EXPECT_FALSE(RE2::FullMatch(
"/index", handler_.kPathRegex));
// Sends a series of requests with unsupported method/path combinations and
// expects each to produce an InvalidArgument status whose message contains
// the given fragment.
// NOTE(review): the declarations of `headers` and `status`, and some
// `status =` prefixes, are on lines not visible in this listing.
187 TEST_F(TFRTHttpRestApiHandlerTest, UnsupportedApiCalls) {
189 string model_name, method, output;
// GET on /v1/foo.
191 status = handler_.ProcessRequest(
"GET",
"/v1/foo",
"", &headers, &model_name,
193 EXPECT_TRUE(errors::IsInvalidArgument(status));
194 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// POST on /v1/foo.
196 status = handler_.ProcessRequest(
"POST",
"/v1/foo",
"", &headers, &model_name,
198 EXPECT_TRUE(errors::IsInvalidArgument(status));
199 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// GET on /v1/models with no model name: "Missing model name".
201 status = handler_.ProcessRequest(
"GET",
"/v1/models",
"", &headers,
202 &model_name, &method, &output);
203 EXPECT_TRUE(errors::IsInvalidArgument(status));
204 EXPECT_THAT(status.message(), HasSubstr(
"Missing model name"));
// POST on /v1/models with no model name.
206 status = handler_.ProcessRequest(
"POST",
"/v1/models",
"", &headers,
207 &model_name, &method, &output);
208 EXPECT_TRUE(errors::IsInvalidArgument(status));
209 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// GET on a :predict path.
211 status = handler_.ProcessRequest(
"GET",
"/v1/models/foo:predict",
"",
212 &headers, &model_name, &method, &output);
213 EXPECT_TRUE(errors::IsInvalidArgument(status));
214 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// GET with a "/version/" path segment and :predict.
216 status = handler_.ProcessRequest(
"GET",
"/v1/models/foo/version/50:predict",
217 "", &headers, &model_name, &method, &output);
218 EXPECT_TRUE(errors::IsInvalidArgument(status));
219 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// POST with a "/version/" path segment and :regress.
221 status = handler_.ProcessRequest(
"POST",
"/v1/models/foo/version/50:regress",
222 "", &headers, &model_name, &method, &output);
223 EXPECT_TRUE(errors::IsInvalidArgument(status));
224 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Non-numeric version "HELLO".
227 handler_.ProcessRequest(
"POST",
"/v1/models/foo/versions/HELLO:regress",
228 "", &headers, &model_name, &method, &output);
229 EXPECT_TRUE(errors::IsInvalidArgument(status));
230 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Version equal to the maximum uint64_t value: "Failed to convert version".
232 status = handler_.ProcessRequest(
234 absl::StrCat(
"/v1/models/foo/versions/",
235 std::numeric_limits<uint64_t>::max(),
":regress"),
236 "", &headers, &model_name, &method, &output);
237 EXPECT_TRUE(errors::IsInvalidArgument(status));
238 EXPECT_THAT(status.message(), HasSubstr(
"Failed to convert version"));
// POST on a /metadata path.
240 status = handler_.ProcessRequest(
"POST",
"/v1/models/foo/metadata",
"",
241 &headers, &model_name, &method, &output);
242 EXPECT_TRUE(errors::IsInvalidArgument(status));
243 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Path with a "/label/" segment.
246 handler_.ProcessRequest(
"POST",
"/v1/models/foo/label/some_label:regress",
247 "", &headers, &model_name, &method, &output);
248 EXPECT_TRUE(errors::IsInvalidArgument(status));
249 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Path containing both a version and a label.
251 status = handler_.ProcessRequest(
252 "POST",
"/v1/models/foo/versions/50/labels/some_label:regress",
"",
253 &headers, &model_name, &method, &output);
254 EXPECT_TRUE(errors::IsInvalidArgument(status));
255 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Label text placed under "/versions/".
257 status = handler_.ProcessRequest(
"POST",
258 "/v1/models/foo/versions/some_label:regress",
259 "", &headers, &model_name, &method, &output);
260 EXPECT_TRUE(errors::IsInvalidArgument(status));
261 EXPECT_THAT(status.message(), HasSubstr(
"Malformed request"));
// Predict requests for model "foo", for "foo" at version 50, and for
// kTestModelName with "99" appended are all expected to yield NotFound.
// NOTE(review): `headers` and `status` are declared on lines not visible in
// this listing.
264 TEST_F(TFRTHttpRestApiHandlerTest, PredictModelNameVersionErrors) {
266 string model_name, method, output;
// Model "foo".
269 status = handler_.ProcessRequest(
"POST",
"/v1/models/foo:predict",
270 R
"({ "instances": [1] })", &headers,
271 &model_name, &method, &output);
272 EXPECT_TRUE(errors::IsNotFound(status));
// Model "foo" at version 50.
275 status = handler_.ProcessRequest(
"POST",
"/v1/models/foo/versions/50:predict",
276 R
"({ "instances": [1] })", &headers,
277 &model_name, &method, &output);
278 EXPECT_TRUE(errors::IsNotFound(status));
// Test model name with "99" appended.
281 status = handler_.ProcessRequest(
282 "POST", absl::StrCat(
"/v1/models/", kTestModelName,
"99:predict"),
283 R
"({ "instances": [1] })", &headers, &model_name, &method, &output);
284 EXPECT_TRUE(errors::IsNotFound(status));
// Sends predict requests with bad bodies or a bad version label to the test
// model and expects InvalidArgument with a specific message fragment in each
// case.
// NOTE(review): `headers` and `status` are declared on lines not visible in
// this listing.
287 TEST_F(TFRTHttpRestApiHandlerTest, PredictRequestErrors) {
289 string model_name, method, output;
291 const string& req_path =
292 absl::StrCat(
"/v1/models/", kTestModelName,
":predict");
// Empty request body.
295 status = handler_.ProcessRequest(
"POST", req_path,
"", &headers, &model_name,
297 EXPECT_TRUE(errors::IsInvalidArgument(status));
298 EXPECT_THAT(status.message(),
299 HasSubstr(
"JSON Parse error: The document is empty"));
// Body that is not JSON.
302 status = handler_.ProcessRequest(
"POST", req_path,
"instances = [1, 2]",
303 &headers, &model_name, &method, &output);
304 EXPECT_TRUE(errors::IsInvalidArgument(status));
305 EXPECT_THAT(status.message(), HasSubstr(
"JSON Parse error: Invalid value"));
// String instances: expects "not of expected type: float".
308 status = handler_.ProcessRequest(
"POST", req_path,
309 R
"({ "instances": ["x", "y"] })", &headers,
310 &model_name, &method, &output);
311 EXPECT_TRUE(errors::IsInvalidArgument(status));
312 EXPECT_THAT(status.message(), HasSubstr("not of expected type: float"));
// Request addressed through kNonexistentModelVersionLabel.
315 status = handler_.ProcessRequest(
317 absl::StrCat(
"/v1/models/", kTestModelName,
"/labels/",
318 kNonexistentModelVersionLabel,
":predict"),
319 R
"({ "instances": ["x", "y"] })", &headers, &model_name, &method,
321 EXPECT_TRUE(errors::IsInvalidArgument(status));
322 EXPECT_THAT(status.message(),
323 HasSubstr("Unrecognized servable version label"));
// Numeric signature_name; the error text is read from the JSON `output`.
327 handler_.ProcessRequest(
"POST", req_path, R
"({ "signature_name": 100 })",
328 &headers, &model_name, &method, &output);
329 EXPECT_TRUE(errors::IsInvalidArgument(status));
330 EXPECT_THAT(GetJsonErrorMsg(output),
331 HasSubstr("'signature_name' key must be a string value."));
// Successful predict requests against the test model, addressed by name only,
// by version, and by label, in both "instances" and "inputs" request forms.
// The cases shown in full check the JSON output and that the only response
// header is Content-Type: application/json.
// NOTE(review): `headers` is declared on a line not visible in this listing.
334 TEST_F(TFRTHttpRestApiHandlerTest, Predict) {
336 string model_name, method, output;
// Model name only, nested instances.
339 TF_EXPECT_OK(handler_.ProcessRequest(
340 "POST", absl::StrCat(
"/v1/models/", kTestModelName,
":predict"),
341 R
"({"instances": [[1.0, 2.0], [3.0, 4.0]]})", &headers, &model_name,
344 CompareJson(output, R"({ "predictions": [[2.5, 3.0], [3.5, 4.0]] })"));
345 EXPECT_THAT(headers, UnorderedElementsAreArray(
346 (HeaderList){{"Content-Type",
"application/json"}}));
// Explicit version.
349 TF_EXPECT_OK(handler_.ProcessRequest(
351 absl::StrCat(
"/v1/models/", kTestModelName,
"/versions/",
352 kTestModelVersion1,
":predict"),
353 R
"({"instances": [1.0, 2.0]})", &headers, &model_name, &method, &output));
354 TF_EXPECT_OK(CompareJson(output, R"({ "predictions": [2.5, 3.0] })"));
355 EXPECT_THAT(headers, UnorderedElementsAreArray(
356 (HeaderList){{"Content-Type",
"application/json"}}));
// Version label.
359 TF_EXPECT_OK(handler_.ProcessRequest(
361 absl::StrCat(
"/v1/models/", kTestModelName,
"/labels/",
362 kTestModelVersionLabel,
":predict"),
363 R
"({"instances": [1.0, 2.0]})", &headers, &model_name, &method, &output));
364 TF_EXPECT_OK(CompareJson(output, R"({ "predictions": [2.5, 3.0] })"));
365 EXPECT_THAT(headers, UnorderedElementsAreArray(
366 (HeaderList){{"Content-Type",
"application/json"}}));
// Explicit signature_name with "instances".
369 TF_EXPECT_OK(handler_.ProcessRequest(
371 absl::StrCat(
"/v1/models/", kTestModelName,
"/versions/",
372 kTestModelVersion1,
":predict"),
373 R
"({"signature_name": "serving_default", "instances": [3.0, 4.0]})",
374 &headers, &model_name, &method, &output));
375 TF_EXPECT_OK(CompareJson(output, R"({ "predictions": [3.5, 4.0] })"));
376 EXPECT_THAT(headers, UnorderedElementsAreArray(
377 (HeaderList){{"Content-Type",
"application/json"}}));
// Explicit signature_name with "inputs"; the response uses "outputs".
380 TF_EXPECT_OK(handler_.ProcessRequest(
382 absl::StrCat(
"/v1/models/", kTestModelName,
"/versions/",
383 kTestModelVersion1,
":predict"),
384 R
"({"signature_name": "serving_default", "inputs": [3.0, 4.0]})",
385 &headers, &model_name, &method, &output));
386 TF_EXPECT_OK(CompareJson(output, R"({ "outputs": [3.5, 4.0] })"));
387 EXPECT_THAT(headers, UnorderedElementsAreArray(
388 (HeaderList){{"Content-Type",
"application/json"}}));
// Successful regress requests (signature "regress_x_to_y", x = 80.0)
// addressed by model name, by version, and by label. Each expects
// {"results": [42]} and a single Content-Type: application/json header.
// NOTE(review): `headers` is declared on a line not visible in this listing.
391 TEST_F(TFRTHttpRestApiHandlerTest, Regress) {
393 string model_name, method, output;
// Model name only.
396 TF_EXPECT_OK(handler_.ProcessRequest(
397 "POST", absl::StrCat(
"/v1/models/", kTestModelName,
":regress"),
398 R
"({"signature_name": "regress_x_to_y", "examples": [ { "x": 80.0 } ] })",
399 &headers, &model_name, &method, &output));
400 TF_EXPECT_OK(CompareJson(output, R"({ "results": [42] })"));
401 EXPECT_THAT(headers, UnorderedElementsAreArray(
402 (HeaderList){{"Content-Type",
"application/json"}}));
// Explicit version.
405 TF_EXPECT_OK(handler_.ProcessRequest(
407 absl::StrCat(
"/v1/models/", kTestModelName,
"/versions/",
408 kTestModelVersion1,
":regress"),
409 R
"({"signature_name": "regress_x_to_y", "examples": [ { "x": 80.0 } ] })",
410 &headers, &model_name, &method, &output));
411 TF_EXPECT_OK(CompareJson(output, R"({ "results": [42] })"));
412 EXPECT_THAT(headers, UnorderedElementsAreArray(
413 (HeaderList){{"Content-Type",
"application/json"}}));
// Version label.
416 TF_EXPECT_OK(handler_.ProcessRequest(
418 absl::StrCat(
"/v1/models/", kTestModelName,
"/labels/",
419 kTestModelVersionLabel,
":regress"),
420 R
"({"signature_name": "regress_x_to_y", "examples": [ { "x": 80.0 } ] })",
421 &headers, &model_name, &method, &output));
422 TF_EXPECT_OK(CompareJson(output, R"({ "results": [42] })"));
423 EXPECT_THAT(headers, UnorderedElementsAreArray(
424 (HeaderList){{"Content-Type",
"application/json"}}));
// Successful classify requests (signature "classify_x_to_y") addressed by
// model name and by label. x = 20.0 is expected to give [[["", 12]]] and
// x = 10.0 to give [[["", 7]]]; both expect a single
// Content-Type: application/json header.
// NOTE(review): `headers` is declared on a line not visible in this listing.
427 TEST_F(TFRTHttpRestApiHandlerTest, Classify) {
429 string model_name, method, output;
// Model name only.
432 TF_EXPECT_OK(handler_.ProcessRequest(
433 "POST", absl::StrCat(
"/v1/models/", kTestModelName,
":classify"),
434 R
"({"signature_name": "classify_x_to_y", "examples": [ { "x": 20.0 } ] })",
435 &headers, &model_name, &method, &output));
436 TF_EXPECT_OK(CompareJson(output, R"({ "results": [[["", 12]]] })"));
437 EXPECT_THAT(headers, UnorderedElementsAreArray(
438 (HeaderList){{"Content-Type",
"application/json"}}));
// Version label.
441 TF_EXPECT_OK(handler_.ProcessRequest(
443 absl::StrCat(
"/v1/models/", kTestModelName,
"/labels/",
444 kTestModelVersionLabel,
":classify"),
445 R
"({"signature_name": "classify_x_to_y", "examples": [ { "x": 10.0 } ] })",
446 &headers, &model_name, &method, &output));
447 TF_EXPECT_OK(CompareJson(output, R"({ "results": [[["", 7]]] })"));
448 EXPECT_THAT(headers, UnorderedElementsAreArray(
449 (HeaderList){{"Content-Type",
"application/json"}}));
// GET status requests for the test model addressed by name, by version, and
// by label. Each checks for a single Content-Type: application/json header
// and compares the output with a JSON document containing a
// "model_version_status" list whose visible entry has state "AVAILABLE".
// NOTE(review): the expected JSON bodies, the `headers` declaration and some
// statement openers are on lines not visible in this listing.
452 TEST_F(TFRTHttpRestApiHandlerTest, GetStatus) {
454 string model_name, method, output;
458 TF_EXPECT_OK(handler_.ProcessRequest(
459 "GET", absl::StrCat(
"/v1/models/", kTestModelName),
"", &headers,
460 &model_name, &method, &output));
461 EXPECT_THAT(headers, UnorderedElementsAreArray(
462 (HeaderList){{
"Content-Type",
"application/json"}}));
463 TF_EXPECT_OK(CompareJson(output, R
"({
464 "model_version_status": [
467 "state": "AVAILABLE",
478 handler_.ProcessRequest(
"GET",
479 absl::StrCat(
"/v1/models/", kTestModelName,
480 "/versions/", kTestModelVersion1),
481 "", &headers, &model_name, &method, &output));
482 EXPECT_THAT(headers, UnorderedElementsAreArray(
483 (HeaderList){{
"Content-Type",
"application/json"}}));
484 TF_EXPECT_OK(CompareJson(output, R
"({
485 "model_version_status": [
488 "state": "AVAILABLE",
499 handler_.ProcessRequest(
"GET",
500 absl::StrCat(
"/v1/models/", kTestModelName,
501 "/labels/", kTestModelVersionLabel),
502 "", &headers, &model_name, &method, &output));
503 EXPECT_THAT(headers, UnorderedElementsAreArray(
504 (HeaderList){{
"Content-Type",
"application/json"}}));
505 TF_EXPECT_OK(CompareJson(output, R
"({
506 "model_version_status": [
509 "state": "AVAILABLE",
// GETs the /metadata resource of the test model, checks for a single
// Content-Type: application/json header, and compares the output with the
// contents of a JSON file read from the testdata directory.
// NOTE(review): `headers` is declared on a line not visible in this listing.
519 TEST_F(TFRTHttpRestApiHandlerTest, GetModelMetadata) {
521 string model_name, method, output;
523 string test_file_contents;
526 TF_EXPECT_OK(handler_.ProcessRequest(
527 "GET", absl::StrCat(
"/v1/models/", kTestModelName,
"/metadata"),
"",
528 &headers, &model_name, &method, &output));
529 EXPECT_THAT(headers, UnorderedElementsAreArray(
530 (HeaderList){{
"Content-Type",
"application/json"}}));
// Path is relative to the current working directory.
531 const string fname = absl::StrCat(
532 "./tensorflow_serving/servables/tensorflow/testdata/"
533 "saved_model_half_plus_two_2_versions_metadata.json");
534 TF_EXPECT_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), fname,
535 &test_file_contents));
536 TF_EXPECT_OK(CompareJson(output, test_file_contents));
static Status Create(Options options, std::unique_ptr< ServerCore > *core)