TensorFlow Serving C++ API Documentation
tfrt_http_rest_api_handler.cc
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/model_servers/tfrt_http_rest_api_handler.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "google/protobuf/any.pb.h"
23 #include "absl/status/status.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/numbers.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_replace.h"
28 #include "absl/strings/string_view.h"
29 #include "absl/time/clock.h"
30 #include "absl/time/time.h"
31 #include "tensorflow/cc/saved_model/loader.h"
32 #include "tensorflow/cc/saved_model/signature_constants.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/protobuf/meta_graph.pb.h"
36 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
37 #include "tsl/platform/errors.h"
38 #include "tensorflow_serving/apis/model.pb.h"
39 #include "tensorflow_serving/apis/predict.pb.h"
40 #include "tensorflow_serving/core/servable_handle.h"
41 #include "tensorflow_serving/model_servers/get_model_status_impl.h"
42 #include "tensorflow_serving/model_servers/http_rest_api_util.h"
43 #include "tensorflow_serving/model_servers/server_core.h"
44 #include "tensorflow_serving/servables/tensorflow/servable.h"
45 #include "tensorflow_serving/servables/tensorflow/tfrt_get_model_metadata_impl.h"
46 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
47 #include "tensorflow_serving/util/json_tensor.h"
48 
49 namespace tensorflow {
50 namespace serving {
51 
53 
54 const char* const TFRTHttpRestApiHandler::kPathRegex =
55  kHTTPRestApiHandlerPathRegex;
56 
57 TFRTHttpRestApiHandler::TFRTHttpRestApiHandler(int timeout_in_ms,
58  ServerCore* core)
59  : run_options_(),
60  timeout_(absl::Milliseconds(timeout_in_ms)),
61  core_(core) {}
62 
63 TFRTHttpRestApiHandler::~TFRTHttpRestApiHandler() {}
64 
65 Status TFRTHttpRestApiHandler::ProcessRequest(
66  const absl::string_view http_method, const absl::string_view request_path,
67  const absl::string_view request_body,
68  std::vector<std::pair<std::string, std::string>>* headers,
69  std::string* model_name, std::string* method, std::string* output) {
70  headers->clear();
71  output->clear();
72  AddHeaders(headers);
73 
74  std::string model_subresource;
75  Status status = errors::InvalidArgument("Malformed request: ", http_method,
76  " ", request_path);
77  absl::optional<int64_t> model_version;
78  absl::optional<std::string> model_version_label;
79  bool parse_successful;
80 
81  TF_RETURN_IF_ERROR(ParseModelInfo(
82  http_method, request_path, model_name, &model_version,
83  &model_version_label, method, &model_subresource, &parse_successful));
84 
85  auto run_options = run_options_;
86  run_options.deadline = absl::Now() + timeout_;
87 
88  // Dispatch request to appropriate processor
89  if (http_method == "POST" && parse_successful) {
90  if (*method == "classify") {
91  status = ProcessClassifyRequest(*model_name, model_version,
92  model_version_label, request_body,
93  run_options, output);
94  } else if (*method == "regress") {
95  status =
96  ProcessRegressRequest(*model_name, model_version, model_version_label,
97  request_body, run_options, output);
98  } else if (*method == "predict") {
99  status =
100  ProcessPredictRequest(*model_name, model_version, model_version_label,
101  request_body, run_options, output);
102  }
103  } else if (http_method == "GET" && parse_successful) {
104  if (!model_subresource.empty() && model_subresource == "metadata") {
105  status = ProcessModelMetadataRequest(*model_name, model_version,
106  model_version_label, output);
107  } else {
108  status = ProcessModelStatusRequest(
109  *model_name, model_version, model_version_label, run_options, output);
110  }
111  }
112 
113  MakeJsonFromStatus(status, output);
114  return status;
115 }
116 
117 Status TFRTHttpRestApiHandler::ProcessClassifyRequest(
118  const absl::string_view model_name,
119  const absl::optional<int64_t>& model_version,
120  const absl::optional<absl::string_view>& model_version_label,
121  const absl::string_view request_body,
122  const Servable::RunOptions& run_options, std::string* output) {
123  ::google::protobuf::Arena arena;
124 
125  auto* request = ::google::protobuf::Arena::Create<ClassificationRequest>(&arena);
126  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
127  model_name, model_version, model_version_label,
128  request->mutable_model_spec()));
129  TF_RETURN_IF_ERROR(FillClassificationRequestFromJson(request_body, request));
130 
131  auto* response = ::google::protobuf::Arena::Create<ClassificationResponse>(&arena);
132  ServableHandle<Servable> servable;
133  TF_RETURN_IF_ERROR(
134  core_->GetServableHandle(request->model_spec(), &servable));
135  TF_RETURN_IF_ERROR(servable->Classify(run_options, *request, response));
136  TF_RETURN_IF_ERROR(
137  MakeJsonFromClassificationResult(response->result(), output));
138  return Status();
139 }
140 
141 Status TFRTHttpRestApiHandler::ProcessRegressRequest(
142  const absl::string_view model_name,
143  const absl::optional<int64_t>& model_version,
144  const absl::optional<absl::string_view>& model_version_label,
145  const absl::string_view request_body,
146  const Servable::RunOptions& run_options, std::string* output) {
147  ::google::protobuf::Arena arena;
148 
149  auto* request = ::google::protobuf::Arena::Create<RegressionRequest>(&arena);
150  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
151  model_name, model_version, model_version_label,
152  request->mutable_model_spec()));
153  TF_RETURN_IF_ERROR(FillRegressionRequestFromJson(request_body, request));
154 
155  auto* response = ::google::protobuf::Arena::Create<RegressionResponse>(&arena);
156  ServableHandle<Servable> servable;
157  TF_RETURN_IF_ERROR(
158  core_->GetServableHandle(request->model_spec(), &servable));
159  TF_RETURN_IF_ERROR(servable->Regress(run_options, *request, response));
160  return MakeJsonFromRegressionResult(response->result(), output);
161 }
162 
163 Status TFRTHttpRestApiHandler::ProcessPredictRequest(
164  const absl::string_view model_name,
165  const absl::optional<int64_t>& model_version,
166  const absl::optional<absl::string_view>& model_version_label,
167  const absl::string_view request_body,
168  const Servable::RunOptions& run_options, std::string* output) {
169  ::google::protobuf::Arena arena;
170 
171  auto* request = ::google::protobuf::Arena::Create<PredictRequest>(&arena);
172  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
173  model_name, model_version, model_version_label,
174  request->mutable_model_spec()));
175 
176  JsonPredictRequestFormat format;
177  TF_RETURN_IF_ERROR(FillPredictRequestFromJson(
178  request_body,
179  [this, request](const std::string& sig,
180  ::google::protobuf::Map<std::string, TensorInfo>* map) {
181  return this->GetInfoMap(request->model_spec(), sig, map);
182  },
183  request, &format));
184 
185  auto* response = ::google::protobuf::Arena::Create<PredictResponse>(&arena);
186 
187  ServableHandle<Servable> servable;
188  TF_RETURN_IF_ERROR(
189  core_->GetServableHandle(request->model_spec(), &servable));
190  TF_RETURN_IF_ERROR(servable->Predict(run_options, *request, response));
191 
192  TF_RETURN_IF_ERROR(MakeJsonFromTensors(response->outputs(), format, output));
193  return Status();
194 }
195 
196 Status TFRTHttpRestApiHandler::ProcessModelStatusRequest(
197  const absl::string_view model_name,
198  const absl::optional<int64_t>& model_version,
199  const absl::optional<absl::string_view>& model_version_label,
200  const Servable::RunOptions& run_options, std::string* output) {
201  // We do not yet support returning status of all models
202  // to be in-sync with the gRPC GetModelStatus API.
203  if (model_name.empty()) {
204  return errors::InvalidArgument("Missing model name in request.");
205  }
206 
207  ::google::protobuf::Arena arena;
208 
209  auto* request = ::google::protobuf::Arena::Create<GetModelStatusRequest>(&arena);
210  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
211  model_name, model_version, model_version_label,
212  request->mutable_model_spec()));
213 
214  auto* response = ::google::protobuf::Arena::Create<GetModelStatusResponse>(&arena);
215  TF_RETURN_IF_ERROR(
216  GetModelStatusImpl::GetModelStatus(core_, *request, response));
217  return ToJsonString(*response, output);
218 }
219 
220 Status TFRTHttpRestApiHandler::ProcessModelMetadataRequest(
221  const absl::string_view model_name,
222  const absl::optional<int64_t>& model_version,
223  const absl::optional<absl::string_view>& model_version_label,
224  std::string* output) {
225  if (model_name.empty()) {
226  return errors::InvalidArgument("Missing model name in request.");
227  }
228 
229  ::google::protobuf::Arena arena;
230 
231  auto* request = ::google::protobuf::Arena::Create<GetModelMetadataRequest>(&arena);
232  // We currently only support the kSignatureDef metadata field
233  request->add_metadata_field(std::string(kSignatureDef));
234  TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
235  model_name, model_version, model_version_label,
236  request->mutable_model_spec()));
237 
238  auto* response = ::google::protobuf::Arena::Create<GetModelMetadataResponse>(&arena);
239  TF_RETURN_IF_ERROR(
240  TFRTGetModelMetadataImpl::GetModelMetadata(core_, *request, response));
241 
242  return ToJsonString(*response, output);
243 }
244 
245 Status TFRTHttpRestApiHandler::GetInfoMap(
246  const ModelSpec& model_spec, const std::string& signature_name,
247  ::google::protobuf::Map<std::string, tensorflow::TensorInfo>* infomap) {
248  ServableHandle<Servable> servable;
249  TF_RETURN_IF_ERROR(core_->GetServableHandle(model_spec, &servable));
250  auto& saved_model =
251  down_cast<TfrtSavedModelServable*>(servable.get())->saved_model();
252  const std::string& signame =
253  signature_name.empty() ? kDefaultServingSignatureDefKey : signature_name;
254  auto iter = saved_model.GetMetaGraphDef().signature_def().find(signame);
255  if (iter == saved_model.GetMetaGraphDef().signature_def().end()) {
256  return errors::InvalidArgument("Serving signature name: \"", signame,
257  "\" not found in signature def");
258  }
259  *infomap = iter->second.inputs();
260  return Status();
261 }
262 
263 } // namespace serving
264 } // namespace tensorflow