TensorFlow Serving C++ API Documentation
http_rest_api_handler.cc
1 /* Copyright 2018 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/model_servers/http_rest_api_handler.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "google/protobuf/any.pb.h"
23 #include "google/protobuf/arena.h"
24 #include "google/protobuf/util/json_util.h"
25 #include "absl/strings/escaping.h"
26 #include "absl/strings/numbers.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_replace.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/time/time.h"
31 #include "tensorflow/cc/saved_model/loader.h"
32 #include "tensorflow/cc/saved_model/signature_constants.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/threadpool_options.h"
36 #include "tensorflow_serving/apis/model.pb.h"
37 #include "tensorflow_serving/apis/predict.pb.h"
38 #include "tensorflow_serving/core/servable_handle.h"
39 #include "tensorflow_serving/model_servers/get_model_status_impl.h"
40 #include "tensorflow_serving/model_servers/http_rest_api_util.h"
41 #include "tensorflow_serving/model_servers/server_core.h"
42 #include "tensorflow_serving/servables/tensorflow/classification_service.h"
43 #include "tensorflow_serving/servables/tensorflow/get_model_metadata_impl.h"
44 #include "tensorflow_serving/servables/tensorflow/predict_impl.h"
45 #include "tensorflow_serving/servables/tensorflow/regression_service.h"
46 #include "tensorflow_serving/util/json_tensor.h"
47 
48 namespace tensorflow {
49 namespace serving {
50 
53 
54 const char* const HttpRestApiHandler::kPathRegex = kHTTPRestApiHandlerPathRegex;
55 
56 HttpRestApiHandler::HttpRestApiHandler(int timeout_in_ms, ServerCore* core)
57  : run_options_(), core_(core), predictor_(new TensorflowPredictor()) {
58  if (timeout_in_ms > 0) {
59  run_options_.set_timeout_in_ms(timeout_in_ms);
60  }
61 }
62 
63 HttpRestApiHandler::~HttpRestApiHandler() {}
64 
65 Status HttpRestApiHandler::ProcessRequest(
66  const absl::string_view http_method, const absl::string_view request_path,
67  const absl::string_view request_body,
68  std::vector<std::pair<string, string>>* headers, string* model_name,
69  string* method, string* output) {
70  headers->clear();
71  output->clear();
72  AddHeaders(headers);
73  string model_subresource;
74  Status status = errors::InvalidArgument("Malformed request: ", http_method,
75  " ", request_path);
76  absl::optional<int64_t> model_version;
77  absl::optional<string> model_version_label;
78  bool parse_successful;
79 
80  TF_RETURN_IF_ERROR(ParseModelInfo(
81  http_method, request_path, model_name, &model_version,
82  &model_version_label, method, &model_subresource, &parse_successful));
83 
84  // Dispatch request to appropriate processor
85  if (http_method == "POST" && parse_successful) {
86  if (*method == "classify") {
87  status =
88  ProcessClassifyRequest(*model_name, model_version,
89  model_version_label, request_body, output);
90  } else if (*method == "regress") {
91  status = ProcessRegressRequest(*model_name, model_version,
92  model_version_label, request_body, output);
93  } else if (*method == "predict") {
94  status = ProcessPredictRequest(*model_name, model_version,
95  model_version_label, request_body, output);
96  }
97  } else if (http_method == "GET" && parse_successful) {
98  if (!model_subresource.empty() && model_subresource == "metadata") {
99  status = ProcessModelMetadataRequest(*model_name, model_version,
100  model_version_label, output);
101  } else {
102  status = ProcessModelStatusRequest(*model_name, model_version,
103  model_version_label, output);
104  }
105  }
106 
107  MakeJsonFromStatus(status, output);
108  return status;
109 }
110 
111 Status HttpRestApiHandler::ProcessClassifyRequest(
112  const absl::string_view model_name,
113  const absl::optional<int64_t>& model_version,
114  const absl::optional<absl::string_view>& model_version_label,
115  const absl::string_view request_body, string* output) {
116  ::google::protobuf::Arena arena;
117 
118  auto* request = ::google::protobuf::Arena::Create<ClassificationRequest>(&arena);
119  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
120  model_name, model_version, model_version_label,
121  request->mutable_model_spec()));
122  TF_RETURN_IF_ERROR(FillClassificationRequestFromJson(request_body, request));
123 
124  auto* response = ::google::protobuf::Arena::Create<ClassificationResponse>(&arena);
125  TF_RETURN_IF_ERROR(TensorflowClassificationServiceImpl::Classify(
126  run_options_, core_, thread::ThreadPoolOptions(), *request, response));
127  TF_RETURN_IF_ERROR(
128  MakeJsonFromClassificationResult(response->result(), output));
129  return absl::OkStatus();
130 }
131 
132 Status HttpRestApiHandler::ProcessRegressRequest(
133  const absl::string_view model_name,
134  const absl::optional<int64_t>& model_version,
135  const absl::optional<absl::string_view>& model_version_label,
136  const absl::string_view request_body, string* output) {
137  ::google::protobuf::Arena arena;
138 
139  auto* request = ::google::protobuf::Arena::Create<RegressionRequest>(&arena);
140  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
141  model_name, model_version, model_version_label,
142  request->mutable_model_spec()));
143  TF_RETURN_IF_ERROR(FillRegressionRequestFromJson(request_body, request));
144 
145  auto* response = ::google::protobuf::Arena::Create<RegressionResponse>(&arena);
146  TF_RETURN_IF_ERROR(TensorflowRegressionServiceImpl::Regress(
147  run_options_, core_, thread::ThreadPoolOptions(), *request, response));
148  TF_RETURN_IF_ERROR(MakeJsonFromRegressionResult(response->result(), output));
149  return absl::OkStatus();
150 }
151 
152 Status HttpRestApiHandler::ProcessPredictRequest(
153  const absl::string_view model_name,
154  const absl::optional<int64_t>& model_version,
155  const absl::optional<absl::string_view>& model_version_label,
156  const absl::string_view request_body, string* output) {
157  ::google::protobuf::Arena arena;
158 
159  auto* request = ::google::protobuf::Arena::Create<PredictRequest>(&arena);
160  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
161  model_name, model_version, model_version_label,
162  request->mutable_model_spec()));
163 
164  JsonPredictRequestFormat format;
165  TF_RETURN_IF_ERROR(FillPredictRequestFromJson(
166  request_body,
167  [this, request](const string& sig,
168  ::google::protobuf::Map<string, TensorInfo>* map) {
169  return this->GetInfoMap(request->model_spec(), sig, map);
170  },
171  request, &format));
172 
173  auto* response = ::google::protobuf::Arena::Create<PredictResponse>(&arena);
174  TF_RETURN_IF_ERROR(
175  predictor_->Predict(run_options_, core_, *request, response));
176  TF_RETURN_IF_ERROR(MakeJsonFromTensors(response->outputs(), format, output));
177  return absl::OkStatus();
178 }
179 
180 Status HttpRestApiHandler::ProcessModelStatusRequest(
181  const absl::string_view model_name,
182  const absl::optional<int64_t>& model_version,
183  const absl::optional<absl::string_view>& model_version_label,
184  string* output) {
185  // We do not yet support returning status of all models
186  // to be in-sync with the gRPC GetModelStatus API.
187  if (model_name.empty()) {
188  return errors::InvalidArgument("Missing model name in request.");
189  }
190 
191  ::google::protobuf::Arena arena;
192 
193  auto* request = ::google::protobuf::Arena::Create<GetModelStatusRequest>(&arena);
194  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
195  model_name, model_version, model_version_label,
196  request->mutable_model_spec()));
197 
198  auto* response = ::google::protobuf::Arena::Create<GetModelStatusResponse>(&arena);
199  TF_RETURN_IF_ERROR(
200  GetModelStatusImpl::GetModelStatus(core_, *request, response));
201  return ToJsonString(*response, output);
202 }
203 
204 Status HttpRestApiHandler::ProcessModelMetadataRequest(
205  const absl::string_view model_name,
206  const absl::optional<int64_t>& model_version,
207  const absl::optional<absl::string_view>& model_version_label,
208  string* output) {
209  if (model_name.empty()) {
210  return errors::InvalidArgument("Missing model name in request.");
211  }
212 
213  ::google::protobuf::Arena arena;
214 
215  auto* request = ::google::protobuf::Arena::Create<GetModelMetadataRequest>(&arena);
216  // We currently only support the kSignatureDef metadata field
217  request->add_metadata_field(GetModelMetadataImpl::kSignatureDef);
218  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
219  model_name, model_version, model_version_label,
220  request->mutable_model_spec()));
221 
222  auto* response = ::google::protobuf::Arena::Create<GetModelMetadataResponse>(&arena);
223  TF_RETURN_IF_ERROR(
224  GetModelMetadataImpl::GetModelMetadata(core_, *request, response));
225  return ToJsonString(*response, output);
226 }
227 
228 Status HttpRestApiHandler::GetInfoMap(
229  const ModelSpec& model_spec, const string& signature_name,
230  ::google::protobuf::Map<string, tensorflow::TensorInfo>* infomap) {
231  ServableHandle<SavedModelBundle> bundle;
232  TF_RETURN_IF_ERROR(core_->GetServableHandle(model_spec, &bundle));
233  const string& signame =
234  signature_name.empty() ? kDefaultServingSignatureDefKey : signature_name;
235  auto iter = bundle->meta_graph_def.signature_def().find(signame);
236  if (iter == bundle->meta_graph_def.signature_def().end()) {
237  return errors::InvalidArgument("Serving signature name: \"", signame,
238  "\" not found in signature def");
239  }
240  *infomap = iter->second.inputs();
241  return absl::OkStatus();
242 }
243 
244 } // namespace serving
245 } // namespace tensorflow