TensorFlow Serving C++ API Documentation
http_server.cc
/* Copyright 2018 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_serving/model_servers/http_server.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "re2/re2.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow_serving/model_servers/http_rest_api_handler.h"
#include "tensorflow_serving/model_servers/http_rest_api_util.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/model_servers/server_init.h"
#include "tensorflow_serving/servables/tensorflow/util.h"
#include "tensorflow_serving/util/net_http/public/response_code_enum.h"
#include "tensorflow_serving/util/net_http/server/public/httpserver.h"
#include "tensorflow_serving/util/net_http/server/public/server_request_interface.h"
#include "tensorflow_serving/util/prometheus_exporter.h"
#include "tensorflow_serving/util/threadpool_executor.h"

namespace tensorflow {
namespace serving {

namespace {

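// Maps a Status code onto the HTTP status code returned to the client
// (e.g. kInvalidArgument -> 400 BAD_REQUEST, kNotFound -> 404 NOT_FOUND);
// anything unrecognized falls back to 500 ERROR.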
net_http::HTTPStatusCode ToHTTPStatusCode(const Status& status) {
  using error::Code;
  using net_http::HTTPStatusCode;
  switch (static_cast<absl::StatusCode>(status.code())) {
    case absl::StatusCode::kOk:
      return HTTPStatusCode::OK;
    case absl::StatusCode::kCancelled:
      return HTTPStatusCode::CLIENT_CLOSED_REQUEST;
    case absl::StatusCode::kUnknown:
      return HTTPStatusCode::ERROR;
    case absl::StatusCode::kInvalidArgument:
      return HTTPStatusCode::BAD_REQUEST;
    case absl::StatusCode::kDeadlineExceeded:
      return HTTPStatusCode::GATEWAY_TO;
    case absl::StatusCode::kNotFound:
      return HTTPStatusCode::NOT_FOUND;
    case absl::StatusCode::kAlreadyExists:
      return HTTPStatusCode::CONFLICT;
    case absl::StatusCode::kPermissionDenied:
      return HTTPStatusCode::FORBIDDEN;
    case absl::StatusCode::kResourceExhausted:
      return HTTPStatusCode::TOO_MANY_REQUESTS;
    case absl::StatusCode::kFailedPrecondition:
      return HTTPStatusCode::BAD_REQUEST;
    case absl::StatusCode::kAborted:
      return HTTPStatusCode::CONFLICT;
    case absl::StatusCode::kOutOfRange:
      return HTTPStatusCode::BAD_REQUEST;
    case absl::StatusCode::kUnimplemented:
      return HTTPStatusCode::NOT_IMP;
    case absl::StatusCode::kInternal:
      return HTTPStatusCode::ERROR;
    case absl::StatusCode::kUnavailable:
      return HTTPStatusCode::SERVICE_UNAV;
    case absl::StatusCode::kDataLoss:
      return HTTPStatusCode::ERROR;
    case absl::StatusCode::kUnauthenticated:
      return HTTPStatusCode::UNAUTHORIZED;
    default:
      return HTTPStatusCode::ERROR;
  }
}

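// Handles a Prometheus metrics scrape: rejects requests whose URI does not
// match the configured path, otherwise renders the metrics page via the
// exporter, then writes the output (or error details) back as text/plain.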
void ProcessPrometheusRequest(PrometheusExporter* exporter, const string& path,
                              net_http::ServerRequestInterface* req) {
  std::vector<std::pair<string, string>> headers;
  headers.push_back({"Content-Type", "text/plain"});
  string output;
  Status status;
  // Check if the request URI matches the configured path.
  if (req->uri_path() != path) {
    output = absl::StrFormat("Unexpected path: %s. Should be %s",
                             req->uri_path(), path);
    status = Status(static_cast<tensorflow::errors::Code>(
                        absl::StatusCode::kInvalidArgument),
                    output);
  } else {
    status = exporter->GeneratePage(&output);
  }
  const net_http::HTTPStatusCode http_status = ToHTTPStatusCode(status);
  // Note: we add headers+output for non-successful status too, in case the
  // output contains details about the error (e.g. error messages).
  for (const auto& kv : headers) {
    req->OverwriteResponseHeader(kv.first, kv.second);
  }
  req->WriteResponseString(output);
  if (http_status != net_http::HTTPStatusCode::OK) {
    VLOG(1) << "Error processing Prometheus metrics request. Error: "
            << status.ToString();
  }
  req->ReplyWithStatus(http_status);
}

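// EventExecutor that schedules the HTTP server's work onto a
// ThreadPoolExecutor with `num_threads` threads.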
class RequestExecutor final : public net_http::EventExecutor {
 public:
  explicit RequestExecutor(int num_threads)
      : executor_(Env::Default(), "httprestserver", num_threads) {}

  void Schedule(std::function<void()> fn) override { executor_.Schedule(fn); }

 private:
  ThreadPoolExecutor executor_;
};

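// Dispatches REST API requests. URIs matching HttpRestApiHandler::kPathRegex
// are processed by a handler obtained from the serving function registry;
// all other requests are ignored (Dispatch returns nullptr). ProcessRequest
// also answers CORS preflight (OPTIONS) requests and records request metrics.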
class RestApiRequestDispatcher {
 public:
  RestApiRequestDispatcher(int timeout_in_ms, ServerCore* core)
      : regex_(HttpRestApiHandler::kPathRegex), core_(core) {
    auto* tf_serving_registry = tensorflow::serving::init::
        TensorflowServingFunctionRegistration::GetRegistry();
    handler_ =
        tf_serving_registry->GetCreateHttpRestApiHandler()(timeout_in_ms, core);
  }

  net_http::RequestHandler Dispatch(net_http::ServerRequestInterface* req) {
    if (RE2::FullMatch(string(req->uri_path()), regex_)) {
      return [this](net_http::ServerRequestInterface* req) {
        this->ProcessRequest(req);
      };
    }
    VLOG(1) << "Ignoring HTTP request: " << req->http_method() << " "
            << req->uri_path();
    return nullptr;
  }

 private:
  void ProcessRequest(net_http::ServerRequestInterface* req) {
    const uint64_t start = Env::Default()->NowMicros();
    string body;
    int64_t num_bytes = 0;
    auto request_chunk = req->ReadRequestBytes(&num_bytes);
    while (request_chunk != nullptr) {
      absl::StrAppend(&body,
                      absl::string_view(request_chunk.get(), num_bytes));
      request_chunk = req->ReadRequestBytes(&num_bytes);
    }

    std::vector<std::pair<string, string>> headers;
    string model_name;
    string method;
    string output;
    VLOG(1) << "Processing HTTP request: " << req->http_method() << " "
            << req->uri_path() << " body: " << body.size() << " bytes.";

    Status status;
    if (req->http_method() == "OPTIONS") {
      absl::string_view origin_header = req->GetRequestHeader("Origin");
      if (RE2::PartialMatch(origin_header, "https?://")) {
        status = absl::OkStatus();
      } else {
        status = errors::FailedPrecondition(
            "Origin header is missing in CORS preflight");
      }
    } else {
      status =
          handler_->ProcessRequest(req->http_method(), req->uri_path(), body,
                                   &headers, &model_name, &method, &output);
    }
    if (core_->enable_cors_support()) {
      AddCORSHeaders(&headers);
    }

    const auto http_status = ToHTTPStatusCode(status);
    // Note: we add headers+output for non-successful status too, in case the
    // output contains details about the error (e.g. error messages).
    for (const auto& kv : headers) {
      req->OverwriteResponseHeader(kv.first, kv.second);
    }
    req->WriteResponseString(output);
    if (http_status == net_http::HTTPStatusCode::OK) {
      RecordRequestLatency(model_name, /*api=*/method, /*entrypoint=*/"REST",
                           Env::Default()->NowMicros() - start);
    } else {
      VLOG(1) << "Error processing HTTP/REST request: " << req->http_method()
              << " " << req->uri_path() << " Error: " << status.ToString();
    }
    RecordModelRequestCount(model_name, status);
    req->ReplyWithStatus(http_status);
  }

  const RE2 regex_;
  ServerCore* core_;
  std::unique_ptr<HttpRestApiHandlerBase> handler_;
};

}  // namespace

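// Creates an EvHTTP server listening on `port`, registers the optional
// Prometheus metrics handler and the REST API dispatcher, and starts
// accepting requests. Returns nullptr if the server cannot be created or
// fails to start.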
std::unique_ptr<net_http::HTTPServerInterface> CreateAndStartHttpServer(
    int port, int num_threads, int timeout_in_ms,
    const MonitoringConfig& monitoring_config, ServerCore* core) {
  auto options = absl::make_unique<net_http::ServerOptions>();
  options->AddPort(static_cast<uint32_t>(port));
  options->SetExecutor(absl::make_unique<RequestExecutor>(num_threads));

  auto server = net_http::CreateEvHTTPServer(std::move(options));
  if (server == nullptr) {
    return nullptr;
  }

  // Register the handler for the Prometheus metrics endpoint.
  if (monitoring_config.prometheus_config().enable()) {
    std::shared_ptr<PrometheusExporter> exporter =
        std::make_shared<PrometheusExporter>();
    net_http::RequestHandlerOptions prometheus_request_options;
    PrometheusConfig prometheus_config = monitoring_config.prometheus_config();
    auto path = prometheus_config.path().empty()
                    ? PrometheusExporter::kPrometheusPath
                    : prometheus_config.path();
    server->RegisterRequestHandler(
        path,
        [exporter, path](net_http::ServerRequestInterface* req) {
          ProcessPrometheusRequest(exporter.get(), path, req);
        },
        prometheus_request_options);
  }

  std::shared_ptr<RestApiRequestDispatcher> dispatcher =
      std::make_shared<RestApiRequestDispatcher>(timeout_in_ms, core);
  net_http::RequestHandlerOptions handler_options;
  server->RegisterRequestDispatcher(
      [dispatcher](net_http::ServerRequestInterface* req) {
        return dispatcher->Dispatch(req);
      },
      handler_options);
  if (server->StartAcceptingRequests()) {
    return server;
  }
  return nullptr;
}

}  // namespace serving
}  // namespace tensorflow
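
Usage sketch (not part of http_server.cc): a minimal illustration of how the CreateAndStartHttpServer entry point might be called from a hosting binary. The RunRestApi wrapper, the literal port/thread/timeout values, and the assumption that a ServerCore instance already exists (e.g. built via ServerCore::Create) are illustrative only; it also assumes HTTPServerInterface::WaitForTermination blocks until the server shuts down.

#include "tensorflow_serving/config/monitoring_config.pb.h"
#include "tensorflow_serving/model_servers/http_server.h"
#include "tensorflow_serving/model_servers/server_core.h"

// Hypothetical wrapper: `core` is assumed to have been created elsewhere and
// to outlive the HTTP server.
void RunRestApi(tensorflow::serving::ServerCore* core) {
  // Default-constructed config leaves the Prometheus endpoint disabled.
  tensorflow::serving::MonitoringConfig monitoring_config;
  auto server = tensorflow::serving::CreateAndStartHttpServer(
      /*port=*/8501, /*num_threads=*/4, /*timeout_in_ms=*/30000,
      monitoring_config, core);
  if (server != nullptr) {
    server->WaitForTermination();  // Block until the server is terminated.
  }
}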