16 #include "tensorflow_serving/model_servers/http_rest_api_handler.h"
22 #include "google/protobuf/any.pb.h"
23 #include "google/protobuf/arena.h"
24 #include "google/protobuf/util/json_util.h"
25 #include "absl/strings/escaping.h"
26 #include "absl/strings/numbers.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_replace.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/time/time.h"
31 #include "tensorflow/cc/saved_model/loader.h"
32 #include "tensorflow/cc/saved_model/signature_constants.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/threadpool_options.h"
36 #include "tensorflow_serving/apis/model.pb.h"
37 #include "tensorflow_serving/apis/predict.pb.h"
38 #include "tensorflow_serving/core/servable_handle.h"
39 #include "tensorflow_serving/model_servers/get_model_status_impl.h"
40 #include "tensorflow_serving/model_servers/http_rest_api_util.h"
41 #include "tensorflow_serving/model_servers/server_core.h"
42 #include "tensorflow_serving/servables/tensorflow/classification_service.h"
43 #include "tensorflow_serving/servables/tensorflow/get_model_metadata_impl.h"
44 #include "tensorflow_serving/servables/tensorflow/predict_impl.h"
45 #include "tensorflow_serving/servables/tensorflow/regression_service.h"
46 #include "tensorflow_serving/util/json_tensor.h"
48 namespace tensorflow {
54 const char*
const HttpRestApiHandler::kPathRegex = kHTTPRestApiHandlerPathRegex;
56 HttpRestApiHandler::HttpRestApiHandler(
int timeout_in_ms, ServerCore* core)
57 : run_options_(), core_(core), predictor_(new TensorflowPredictor()) {
58 if (timeout_in_ms > 0) {
59 run_options_.set_timeout_in_ms(timeout_in_ms);
63 HttpRestApiHandler::~HttpRestApiHandler() {}
65 Status HttpRestApiHandler::ProcessRequest(
66 const absl::string_view http_method,
const absl::string_view request_path,
67 const absl::string_view request_body,
68 std::vector<std::pair<string, string>>* headers,
string* model_name,
69 string* method,
string* output) {
73 string model_subresource;
74 Status status = errors::InvalidArgument(
"Malformed request: ", http_method,
76 absl::optional<int64_t> model_version;
77 absl::optional<string> model_version_label;
78 bool parse_successful;
80 TF_RETURN_IF_ERROR(ParseModelInfo(
81 http_method, request_path, model_name, &model_version,
82 &model_version_label, method, &model_subresource, &parse_successful));
85 if (http_method ==
"POST" && parse_successful) {
86 if (*method ==
"classify") {
88 ProcessClassifyRequest(*model_name, model_version,
89 model_version_label, request_body, output);
90 }
else if (*method ==
"regress") {
91 status = ProcessRegressRequest(*model_name, model_version,
92 model_version_label, request_body, output);
93 }
else if (*method ==
"predict") {
94 status = ProcessPredictRequest(*model_name, model_version,
95 model_version_label, request_body, output);
97 }
else if (http_method ==
"GET" && parse_successful) {
98 if (!model_subresource.empty() && model_subresource ==
"metadata") {
99 status = ProcessModelMetadataRequest(*model_name, model_version,
100 model_version_label, output);
102 status = ProcessModelStatusRequest(*model_name, model_version,
103 model_version_label, output);
107 MakeJsonFromStatus(status, output);
111 Status HttpRestApiHandler::ProcessClassifyRequest(
112 const absl::string_view model_name,
113 const absl::optional<int64_t>& model_version,
114 const absl::optional<absl::string_view>& model_version_label,
115 const absl::string_view request_body,
string* output) {
116 ::google::protobuf::Arena arena;
118 auto* request = ::google::protobuf::Arena::Create<ClassificationRequest>(&arena);
119 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
120 model_name, model_version, model_version_label,
121 request->mutable_model_spec()));
122 TF_RETURN_IF_ERROR(FillClassificationRequestFromJson(request_body, request));
124 auto* response = ::google::protobuf::Arena::Create<ClassificationResponse>(&arena);
125 TF_RETURN_IF_ERROR(TensorflowClassificationServiceImpl::Classify(
126 run_options_, core_, thread::ThreadPoolOptions(), *request, response));
128 MakeJsonFromClassificationResult(response->result(), output));
129 return absl::OkStatus();
132 Status HttpRestApiHandler::ProcessRegressRequest(
133 const absl::string_view model_name,
134 const absl::optional<int64_t>& model_version,
135 const absl::optional<absl::string_view>& model_version_label,
136 const absl::string_view request_body,
string* output) {
137 ::google::protobuf::Arena arena;
139 auto* request = ::google::protobuf::Arena::Create<RegressionRequest>(&arena);
140 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
141 model_name, model_version, model_version_label,
142 request->mutable_model_spec()));
143 TF_RETURN_IF_ERROR(FillRegressionRequestFromJson(request_body, request));
145 auto* response = ::google::protobuf::Arena::Create<RegressionResponse>(&arena);
146 TF_RETURN_IF_ERROR(TensorflowRegressionServiceImpl::Regress(
147 run_options_, core_, thread::ThreadPoolOptions(), *request, response));
148 TF_RETURN_IF_ERROR(MakeJsonFromRegressionResult(response->result(), output));
149 return absl::OkStatus();
152 Status HttpRestApiHandler::ProcessPredictRequest(
153 const absl::string_view model_name,
154 const absl::optional<int64_t>& model_version,
155 const absl::optional<absl::string_view>& model_version_label,
156 const absl::string_view request_body,
string* output) {
157 ::google::protobuf::Arena arena;
159 auto* request = ::google::protobuf::Arena::Create<PredictRequest>(&arena);
160 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
161 model_name, model_version, model_version_label,
162 request->mutable_model_spec()));
164 JsonPredictRequestFormat format;
165 TF_RETURN_IF_ERROR(FillPredictRequestFromJson(
167 [
this, request](
const string& sig,
168 ::google::protobuf::Map<string, TensorInfo>* map) {
169 return this->GetInfoMap(request->model_spec(), sig, map);
173 auto* response = ::google::protobuf::Arena::Create<PredictResponse>(&arena);
175 predictor_->Predict(run_options_, core_, *request, response));
176 TF_RETURN_IF_ERROR(MakeJsonFromTensors(response->outputs(), format, output));
177 return absl::OkStatus();
180 Status HttpRestApiHandler::ProcessModelStatusRequest(
181 const absl::string_view model_name,
182 const absl::optional<int64_t>& model_version,
183 const absl::optional<absl::string_view>& model_version_label,
187 if (model_name.empty()) {
188 return errors::InvalidArgument(
"Missing model name in request.");
191 ::google::protobuf::Arena arena;
193 auto* request = ::google::protobuf::Arena::Create<GetModelStatusRequest>(&arena);
194 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
195 model_name, model_version, model_version_label,
196 request->mutable_model_spec()));
198 auto* response = ::google::protobuf::Arena::Create<GetModelStatusResponse>(&arena);
200 GetModelStatusImpl::GetModelStatus(core_, *request, response));
201 return ToJsonString(*response, output);
204 Status HttpRestApiHandler::ProcessModelMetadataRequest(
205 const absl::string_view model_name,
206 const absl::optional<int64_t>& model_version,
207 const absl::optional<absl::string_view>& model_version_label,
209 if (model_name.empty()) {
210 return errors::InvalidArgument(
"Missing model name in request.");
213 ::google::protobuf::Arena arena;
215 auto* request = ::google::protobuf::Arena::Create<GetModelMetadataRequest>(&arena);
217 request->add_metadata_field(GetModelMetadataImpl::kSignatureDef);
218 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
219 model_name, model_version, model_version_label,
220 request->mutable_model_spec()));
222 auto* response = ::google::protobuf::Arena::Create<GetModelMetadataResponse>(&arena);
224 GetModelMetadataImpl::GetModelMetadata(core_, *request, response));
225 return ToJsonString(*response, output);
228 Status HttpRestApiHandler::GetInfoMap(
229 const ModelSpec& model_spec,
const string& signature_name,
230 ::google::protobuf::Map<string, tensorflow::TensorInfo>* infomap) {
231 ServableHandle<SavedModelBundle> bundle;
232 TF_RETURN_IF_ERROR(core_->GetServableHandle(model_spec, &bundle));
233 const string& signame =
234 signature_name.empty() ? kDefaultServingSignatureDefKey : signature_name;
235 auto iter = bundle->meta_graph_def.signature_def().find(signame);
236 if (iter == bundle->meta_graph_def.signature_def().end()) {
237 return errors::InvalidArgument(
"Serving signature name: \"", signame,
238 "\" not found in signature def");
240 *infomap = iter->second.inputs();
241 return absl::OkStatus();