16 #include "tensorflow_serving/model_servers/http_rest_api_util.h"
21 #include "google/protobuf/util/json_util.h"
22 #include "absl/strings/numbers.h"
23 #include <curl/curl.h>
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow_serving/servables/tensorflow/get_model_metadata_impl.h"
28 namespace tensorflow {
// RE2 pattern for prediction request paths:
//   /v1/models/<name>[/versions/<digits> | /labels/<label>]:<verb>
// where <verb> is one of classify, regress or predict. `(?i)` makes the whole
// match case-insensitive. Capture groups, in order: model name, numeric
// version, version label, verb.
//
// This must stay a single raw string literal so that `\d` reaches the regex
// engine unmodified. The identifier's spelling ("Gegex") is kept as-is because
// other code refers to it by this name.
const char* const kPredictionApiGegex =
    R"((?i)/v1/models/([^/:]+)(?:(?:/versions/(\d+))|(?:/labels/([^/:]+)))?:(classify|regress|predict))";
// RE2 pattern for model status / metadata request paths:
//   /v1/models[/<name>][/versions/<digits> | /labels/<label>][/metadata]
// `(?i)` makes the whole match case-insensitive. Capture groups, in order:
// model name, numeric version, version label, and the literal "metadata"
// sub-resource when present.
//
// This must stay a single raw string literal so that `\d` and `\/` reach the
// regex engine unmodified.
const char* const kModelStatusApiRegex =
    R"((?i)/v1/models(?:/([^/:]+))?(?:(?:/versions/(\d+))|(?:/labels/([^/:]+)))?(?:\/(metadata))?)";
// Appends the JSON content-type header to `headers`.
// `headers` must be non-null; existing entries are left untouched.
void AddHeaders(std::vector<std::pair<string, string>>* headers) {
  headers->emplace_back("Content-Type", "application/json");
}
// Appends the three CORS response headers (allowed origin, methods and
// request headers) to `headers`, in that order.
// `headers` must be non-null; existing entries are left untouched.
void AddCORSHeaders(std::vector<std::pair<string, string>>* headers) {
  headers->emplace_back("Access-Control-Allow-Origin", "*");
  headers->emplace_back("Access-Control-Allow-Methods", "POST, GET");
  headers->emplace_back("Access-Control-Allow-Headers", "Content-Type");
}
46 Status FillModelSpecWithNameVersionAndLabel(
47 const absl::string_view model_name,
48 const absl::optional<int64_t>& model_version,
49 const absl::optional<absl::string_view> model_version_label,
50 ::tensorflow::serving::ModelSpec* model_spec) {
51 model_spec->set_name(
string(model_name));
53 if (model_version.has_value() && model_version_label.has_value()) {
54 return errors::InvalidArgument(
55 "Both model version (", model_version.value(),
56 ") and model version label (", model_version_label.value(),
57 ") cannot be supplied.");
60 if (model_version.has_value()) {
61 model_spec->mutable_version()->set_value(model_version.value());
63 if (model_version_label.has_value()) {
64 model_spec->set_version_label(
string(model_version_label.value()));
66 return absl::OkStatus();
69 bool DecodeArg(
string* arg) {
70 static const bool run_once ABSL_ATTRIBUTE_UNUSED = [&]() {
71 curl_global_init(CURL_GLOBAL_ALL);
74 CURL* curl = curl_easy_init();
75 if (curl !=
nullptr) {
78 curl_easy_unescape(curl, arg->c_str(), arg->size(), &outlength);
79 if (cres ==
nullptr) {
82 arg->assign(cres, outlength);
84 curl_easy_cleanup(curl);
90 Status ParseModelInfo(
const absl::string_view http_method,
91 const absl::string_view request_path,
string* model_name,
92 absl::optional<int64_t>* model_version,
93 absl::optional<string>* model_version_label,
94 string* method,
string* model_subresource,
95 bool* parse_successful) {
96 string model_version_str;
97 string model_version_label_str;
99 if (http_method ==
"POST") {
101 RE2::FullMatch(
string(request_path), kPredictionApiGegex, model_name,
102 &model_version_str, &model_version_label_str, method);
103 }
else if (http_method ==
"GET") {
104 *parse_successful = RE2::FullMatch(
105 string(request_path), kModelStatusApiRegex, model_name,
106 &model_version_str, &model_version_label_str, model_subresource);
108 if (!model_name->empty()) {
109 if (!DecodeArg(model_name)) {
110 return errors::InvalidArgument(
"Failed to decode model name:",
114 if (!model_version_str.empty()) {
116 if (!absl::SimpleAtoi(model_version_str, &version)) {
117 return errors::InvalidArgument(
118 "Failed to convert version: ", model_version_str,
" to numeric.");
120 *model_version = version;
122 if (!model_version_label_str.empty()) {
123 if (!DecodeArg(&model_version_label_str)) {
124 return errors::InvalidArgument(
"Failed to decode model version label:",
125 model_version_label_str);
127 *model_version_label = model_version_label_str;
129 return absl::OkStatus();
132 Status ToJsonString(
const GetModelStatusResponse& response,
string* output) {
133 google::protobuf::util::JsonPrintOptions opts;
134 opts.add_whitespace =
true;
135 opts.always_print_primitive_fields =
true;
137 const auto& status = MessageToJsonString(response, output, opts);
139 return errors::Internal(
"Failed to convert proto to json. Error: ",
142 return absl::OkStatus();
145 Status ToJsonString(
const GetModelMetadataResponse& response,
string* output) {
146 google::protobuf::util::JsonPrintOptions opts;
147 opts.add_whitespace =
true;
148 opts.always_print_primitive_fields =
true;
155 opts.preserve_proto_field_names =
true;
157 string model_spec_output;
158 const auto& status1 =
159 MessageToJsonString(response.model_spec(), &model_spec_output, opts);
161 return errors::Internal(
162 "Failed to convert model spec proto to json. Error: ",
166 tensorflow::serving::SignatureDefMap signature_def_map;
167 if (response.metadata().end() ==
168 response.metadata().find(GetModelMetadataImpl::kSignatureDef)) {
169 return errors::Internal(
170 "Failed to find 'signature_def' key in the GetModelMetadataResponse "
173 bool unpack_status = response.metadata()
174 .at(GetModelMetadataImpl::kSignatureDef)
175 .UnpackTo(&signature_def_map);
176 if (!unpack_status) {
177 return errors::Internal(
178 "Failed to unpack 'Any' object to 'SignatureDefMap'.");
181 string signature_def_output;
182 const auto& status2 =
183 MessageToJsonString(signature_def_map, &signature_def_output, opts);
185 return errors::Internal(
186 "Failed to convert signature def proto to json. Error: ",
191 absl::StrAppend(output,
"{\n");
192 absl::StrAppend(output,
"\"model_spec\":", model_spec_output,
",\n");
193 absl::StrAppend(output,
"\"metadata\": {");
194 absl::StrAppend(output,
"\"signature_def\": ", signature_def_output,
"}\n");
195 absl::StrAppend(output,
"}\n");
197 return absl::OkStatus();