TensorFlow Serving C++ API Documentation
http_rest_api_util.cc
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow_serving/model_servers/http_rest_api_util.h"

#include <limits>
#include <utility>
#include <vector>

#include "google/protobuf/util/json_util.h"
#include "absl/strings/numbers.h"
#include <curl/curl.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow_serving/servables/tensorflow/get_model_metadata_impl.h"
27 
28 namespace tensorflow {
29 namespace serving {
30 
// Case-insensitive pattern for POST inference paths:
//   /v1/models/<name>[/versions/<N> | /labels/<label>]:(classify|regress|predict)
// Capture groups, in order: model name, numeric version, version label, method.
// (The identifier's "Gegex" spelling is kept because it is declared that way
// elsewhere.)
const char *const kPredictionApiGegex{
    R"((?i)/v1/models/([^/:]+)(?:(?:/versions/(\d+))|(?:/labels/([^/:]+)))?:(classify|regress|predict))"};

// Case-insensitive pattern for GET status/metadata paths:
//   /v1/models[/<name>][/versions/<N> | /labels/<label>][/metadata]
// Capture groups, in order: model name, numeric version, version label,
// subresource ("metadata").
const char *const kModelStatusApiRegex{
    R"((?i)/v1/models(?:/([^/:]+))?(?:(?:/versions/(\d+))|(?:/labels/([^/:]+)))?(?:\/(metadata))?)"};
35 
// Appends the JSON content-type header to `headers`; existing entries are
// left in place.
void AddHeaders(std::vector<std::pair<string, string>>* headers) {
  headers->emplace_back("Content-Type", "application/json");
}
39 
// Appends permissive CORS response headers (any origin; POST and GET;
// Content-Type request header) to `headers`, after any existing entries.
void AddCORSHeaders(std::vector<std::pair<string, string>>* headers) {
  headers->insert(headers->end(),
                  {{"Access-Control-Allow-Origin", "*"},
                   {"Access-Control-Allow-Methods", "POST, GET"},
                   {"Access-Control-Allow-Headers", "Content-Type"}});
}
45 
46 Status FillModelSpecWithNameVersionAndLabel(
47  const absl::string_view model_name,
48  const absl::optional<int64_t>& model_version,
49  const absl::optional<absl::string_view> model_version_label,
50  ::tensorflow::serving::ModelSpec* model_spec) {
51  model_spec->set_name(string(model_name));
52 
53  if (model_version.has_value() && model_version_label.has_value()) {
54  return errors::InvalidArgument(
55  "Both model version (", model_version.value(),
56  ") and model version label (", model_version_label.value(),
57  ") cannot be supplied.");
58  }
59 
60  if (model_version.has_value()) {
61  model_spec->mutable_version()->set_value(model_version.value());
62  }
63  if (model_version_label.has_value()) {
64  model_spec->set_version_label(string(model_version_label.value()));
65  }
66  return absl::OkStatus();
67 }
68 
69 bool DecodeArg(string* arg) {
70  static const bool run_once ABSL_ATTRIBUTE_UNUSED = [&]() {
71  curl_global_init(CURL_GLOBAL_ALL);
72  return true;
73  }();
74  CURL* curl = curl_easy_init();
75  if (curl != nullptr) {
76  int outlength;
77  char* cres =
78  curl_easy_unescape(curl, arg->c_str(), arg->size(), &outlength);
79  if (cres == nullptr) {
80  return false;
81  }
82  arg->assign(cres, outlength);
83  curl_free(cres);
84  curl_easy_cleanup(curl);
85  return true;
86  }
87  return false;
88 }
89 
90 Status ParseModelInfo(const absl::string_view http_method,
91  const absl::string_view request_path, string* model_name,
92  absl::optional<int64_t>* model_version,
93  absl::optional<string>* model_version_label,
94  string* method, string* model_subresource,
95  bool* parse_successful) {
96  string model_version_str;
97  string model_version_label_str;
98  // Parse request parameters
99  if (http_method == "POST") {
100  *parse_successful =
101  RE2::FullMatch(string(request_path), kPredictionApiGegex, model_name,
102  &model_version_str, &model_version_label_str, method);
103  } else if (http_method == "GET") {
104  *parse_successful = RE2::FullMatch(
105  string(request_path), kModelStatusApiRegex, model_name,
106  &model_version_str, &model_version_label_str, model_subresource);
107  }
108  if (!model_name->empty()) {
109  if (!DecodeArg(model_name)) {
110  return errors::InvalidArgument("Failed to decode model name:",
111  *model_name);
112  }
113  }
114  if (!model_version_str.empty()) {
115  int64_t version;
116  if (!absl::SimpleAtoi(model_version_str, &version)) {
117  return errors::InvalidArgument(
118  "Failed to convert version: ", model_version_str, " to numeric.");
119  }
120  *model_version = version;
121  }
122  if (!model_version_label_str.empty()) {
123  if (!DecodeArg(&model_version_label_str)) {
124  return errors::InvalidArgument("Failed to decode model version label:",
125  model_version_label_str);
126  }
127  *model_version_label = model_version_label_str;
128  }
129  return absl::OkStatus();
130 }
131 
132 Status ToJsonString(const GetModelStatusResponse& response, string* output) {
133  google::protobuf::util::JsonPrintOptions opts;
134  opts.add_whitespace = true;
135  opts.always_print_primitive_fields = true;
136  // Note this is protobuf::util::Status (not TF Status) object.
137  const auto& status = MessageToJsonString(response, output, opts);
138  if (!status.ok()) {
139  return errors::Internal("Failed to convert proto to json. Error: ",
140  status.ToString());
141  }
142  return absl::OkStatus();
143 }
144 
145 Status ToJsonString(const GetModelMetadataResponse& response, string* output) {
146  google::protobuf::util::JsonPrintOptions opts;
147  opts.add_whitespace = true;
148  opts.always_print_primitive_fields = true;
149  // TODO(b/118381513): preserving proto field names on 'Any' fields has been
150  // fixed in the master branch of OSS protobuf but the TF ecosystem is
151  // currently using v3.6.0 where the fix is not present. To resolve the issue
152  // we invoke MessageToJsonString on invididual fields and concatenate the
153  // resulting strings and make it valid JSON that conforms with the response we
154  // expect.
155  opts.preserve_proto_field_names = true;
156 
157  string model_spec_output;
158  const auto& status1 =
159  MessageToJsonString(response.model_spec(), &model_spec_output, opts);
160  if (!status1.ok()) {
161  return errors::Internal(
162  "Failed to convert model spec proto to json. Error: ",
163  status1.ToString());
164  }
165 
166  tensorflow::serving::SignatureDefMap signature_def_map;
167  if (response.metadata().end() ==
168  response.metadata().find(GetModelMetadataImpl::kSignatureDef)) {
169  return errors::Internal(
170  "Failed to find 'signature_def' key in the GetModelMetadataResponse "
171  "metadata map.");
172  }
173  bool unpack_status = response.metadata()
174  .at(GetModelMetadataImpl::kSignatureDef)
175  .UnpackTo(&signature_def_map);
176  if (!unpack_status) {
177  return errors::Internal(
178  "Failed to unpack 'Any' object to 'SignatureDefMap'.");
179  }
180 
181  string signature_def_output;
182  const auto& status2 =
183  MessageToJsonString(signature_def_map, &signature_def_output, opts);
184  if (!status2.ok()) {
185  return errors::Internal(
186  "Failed to convert signature def proto to json. Error: ",
187  status2.ToString());
188  }
189 
190  // Concatenate the resulting strings into a valid JSON format.
191  absl::StrAppend(output, "{\n");
192  absl::StrAppend(output, "\"model_spec\":", model_spec_output, ",\n");
193  absl::StrAppend(output, "\"metadata\": {");
194  absl::StrAppend(output, "\"signature_def\": ", signature_def_output, "}\n");
195  absl::StrAppend(output, "}\n");
196 
197  return absl::OkStatus();
198 }
199 
200 } // namespace serving
201 } // namespace tensorflow