16 #include "tensorflow_serving/model_servers/http_server.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/string_view.h"
28 #include "tensorflow/core/platform/env.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow_serving/model_servers/http_rest_api_handler.h"
31 #include "tensorflow_serving/model_servers/http_rest_api_util.h"
32 #include "tensorflow_serving/model_servers/server_core.h"
33 #include "tensorflow_serving/model_servers/server_init.h"
34 #include "tensorflow_serving/servables/tensorflow/util.h"
35 #include "tensorflow_serving/util/net_http/public/response_code_enum.h"
36 #include "tensorflow_serving/util/net_http/server/public/httpserver.h"
37 #include "tensorflow_serving/util/net_http/server/public/server_request_interface.h"
38 #include "tensorflow_serving/util/prometheus_exporter.h"
39 #include "tensorflow_serving/util/threadpool_executor.h"
41 namespace tensorflow {
46 net_http::HTTPStatusCode ToHTTPStatusCode(
const Status& status) {
48 using net_http::HTTPStatusCode;
49 switch (
static_cast<absl::StatusCode
>(status.code())) {
50 case absl::StatusCode::kOk:
51 return HTTPStatusCode::OK;
52 case absl::StatusCode::kCancelled:
53 return HTTPStatusCode::CLIENT_CLOSED_REQUEST;
54 case absl::StatusCode::kUnknown:
55 return HTTPStatusCode::ERROR;
56 case absl::StatusCode::kInvalidArgument:
57 return HTTPStatusCode::BAD_REQUEST;
58 case absl::StatusCode::kDeadlineExceeded:
59 return HTTPStatusCode::GATEWAY_TO;
60 case absl::StatusCode::kNotFound:
61 return HTTPStatusCode::NOT_FOUND;
62 case absl::StatusCode::kAlreadyExists:
63 return HTTPStatusCode::CONFLICT;
64 case absl::StatusCode::kPermissionDenied:
65 return HTTPStatusCode::FORBIDDEN;
66 case absl::StatusCode::kResourceExhausted:
67 return HTTPStatusCode::TOO_MANY_REQUESTS;
68 case absl::StatusCode::kFailedPrecondition:
69 return HTTPStatusCode::BAD_REQUEST;
70 case absl::StatusCode::kAborted:
71 return HTTPStatusCode::CONFLICT;
72 case absl::StatusCode::kOutOfRange:
73 return HTTPStatusCode::BAD_REQUEST;
74 case absl::StatusCode::kUnimplemented:
75 return HTTPStatusCode::NOT_IMP;
76 case absl::StatusCode::kInternal:
77 return HTTPStatusCode::ERROR;
78 case absl::StatusCode::kUnavailable:
79 return HTTPStatusCode::SERVICE_UNAV;
80 case absl::StatusCode::kDataLoss:
81 return HTTPStatusCode::ERROR;
82 case absl::StatusCode::kUnauthenticated:
83 return HTTPStatusCode::UNAUTHORIZED;
85 return HTTPStatusCode::ERROR;
89 void ProcessPrometheusRequest(PrometheusExporter* exporter,
const string& path,
90 net_http::ServerRequestInterface* req) {
91 std::vector<std::pair<string, string>> headers;
92 headers.push_back({
"Content-Type",
"text/plain"});
96 if (req->uri_path() != path) {
97 output = absl::StrFormat(
"Unexpected path: %s. Should be %s",
98 req->uri_path(), path);
99 status = Status(
static_cast<tensorflow::errors::Code
>(
100 absl::StatusCode::kInvalidArgument),
103 status = exporter->GeneratePage(&output);
105 const net_http::HTTPStatusCode http_status = ToHTTPStatusCode(status);
108 for (
const auto& kv : headers) {
109 req->OverwriteResponseHeader(kv.first, kv.second);
111 req->WriteResponseString(output);
112 if (http_status != net_http::HTTPStatusCode::OK) {
113 VLOG(1) <<
"Error Processing prometheus metrics request. Error: "
114 << status.ToString();
116 req->ReplyWithStatus(http_status);
119 class RequestExecutor final :
public net_http::EventExecutor {
121 explicit RequestExecutor(
int num_threads)
122 : executor_(Env::Default(),
"httprestserver", num_threads) {}
124 void Schedule(std::function<
void()> fn)
override { executor_.Schedule(fn); }
127 ThreadPoolExecutor executor_;
130 class RestApiRequestDispatcher {
132 RestApiRequestDispatcher(
int timeout_in_ms, ServerCore* core)
133 : regex_(HttpRestApiHandler::kPathRegex), core_(core) {
134 auto* tf_serving_registry = tensorflow::serving::init::
135 TensorflowServingFunctionRegistration::GetRegistry();
137 tf_serving_registry->GetCreateHttpRestApiHandler()(timeout_in_ms, core);
140 net_http::RequestHandler Dispatch(net_http::ServerRequestInterface* req) {
141 if (RE2::FullMatch(
string(req->uri_path()), regex_)) {
142 return [
this](net_http::ServerRequestInterface* req) {
143 this->ProcessRequest(req);
146 VLOG(1) <<
"Ignoring HTTP request: " << req->http_method() <<
" "
152 void ProcessRequest(net_http::ServerRequestInterface* req) {
153 const uint64_t start = Env::Default()->NowMicros();
155 int64_t num_bytes = 0;
156 auto request_chunk = req->ReadRequestBytes(&num_bytes);
157 while (request_chunk !=
nullptr) {
158 absl::StrAppend(&body, absl::string_view(request_chunk.get(), num_bytes));
159 request_chunk = req->ReadRequestBytes(&num_bytes);
162 std::vector<std::pair<string, string>> headers;
166 VLOG(1) <<
"Processing HTTP request: " << req->http_method() <<
" "
167 << req->uri_path() <<
" body: " << body.size() <<
" bytes.";
170 if (req->http_method() ==
"OPTIONS") {
171 absl::string_view origin_header = req->GetRequestHeader(
"Origin");
172 if (RE2::PartialMatch(origin_header,
"https?://")) {
173 status = absl::OkStatus();
175 status = errors::FailedPrecondition(
176 "Origin header is missing in CORS preflight");
180 handler_->ProcessRequest(req->http_method(), req->uri_path(), body,
181 &headers, &model_name, &method, &output);
183 if (core_->enable_cors_support()) {
184 AddCORSHeaders(&headers);
187 const auto http_status = ToHTTPStatusCode(status);
190 for (
const auto& kv : headers) {
191 req->OverwriteResponseHeader(kv.first, kv.second);
193 req->WriteResponseString(output);
194 if (http_status == net_http::HTTPStatusCode::OK) {
195 RecordRequestLatency(model_name, method,
"REST",
196 Env::Default()->NowMicros() - start);
198 VLOG(1) <<
"Error Processing HTTP/REST request: " << req->http_method()
199 <<
" " << req->uri_path() <<
" Error: " << status.ToString();
201 RecordModelRequestCount(model_name, status);
202 req->ReplyWithStatus(http_status);
207 std::unique_ptr<HttpRestApiHandlerBase> handler_;
212 std::unique_ptr<net_http::HTTPServerInterface> CreateAndStartHttpServer(
213 int port,
int num_threads,
int timeout_in_ms,
214 const MonitoringConfig& monitoring_config, ServerCore* core) {
215 auto options = absl::make_unique<net_http::ServerOptions>();
216 options->AddPort(
static_cast<uint32_t
>(port));
217 options->SetExecutor(absl::make_unique<RequestExecutor>(num_threads));
219 auto server = net_http::CreateEvHTTPServer(std::move(options));
220 if (server ==
nullptr) {
225 if (monitoring_config.prometheus_config().enable()) {
226 std::shared_ptr<PrometheusExporter> exporter =
227 std::make_shared<PrometheusExporter>();
228 net_http::RequestHandlerOptions prometheus_request_options;
229 PrometheusConfig prometheus_config = monitoring_config.prometheus_config();
230 auto path = prometheus_config.path().empty()
231 ? PrometheusExporter::kPrometheusPath
232 : prometheus_config.path();
233 server->RegisterRequestHandler(
235 [exporter, path](net_http::ServerRequestInterface* req) {
236 ProcessPrometheusRequest(exporter.get(), path, req);
238 prometheus_request_options);
241 std::shared_ptr<RestApiRequestDispatcher> dispatcher =
242 std::make_shared<RestApiRequestDispatcher>(timeout_in_ms, core);
243 net_http::RequestHandlerOptions handler_options;
244 server->RegisterRequestDispatcher(
245 [dispatcher](net_http::ServerRequestInterface* req) {
246 return dispatcher->Dispatch(req);
249 if (server->StartAcceptingRequests()) {