16 #include "tensorflow_serving/model_servers/tfrt_http_rest_api_handler.h"
22 #include "google/protobuf/any.pb.h"
23 #include "absl/status/status.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/numbers.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_replace.h"
28 #include "absl/strings/string_view.h"
29 #include "absl/time/clock.h"
30 #include "absl/time/time.h"
31 #include "tensorflow/cc/saved_model/loader.h"
32 #include "tensorflow/cc/saved_model/signature_constants.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/protobuf/meta_graph.pb.h"
36 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
37 #include "tsl/platform/errors.h"
38 #include "tensorflow_serving/apis/model.pb.h"
39 #include "tensorflow_serving/apis/predict.pb.h"
40 #include "tensorflow_serving/core/servable_handle.h"
41 #include "tensorflow_serving/model_servers/get_model_status_impl.h"
42 #include "tensorflow_serving/model_servers/http_rest_api_util.h"
43 #include "tensorflow_serving/model_servers/server_core.h"
44 #include "tensorflow_serving/servables/tensorflow/servable.h"
45 #include "tensorflow_serving/servables/tensorflow/tfrt_get_model_metadata_impl.h"
46 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
47 #include "tensorflow_serving/util/json_tensor.h"
49 namespace tensorflow {
// Out-of-class definition of TFRTHttpRestApiHandler::kPathRegex. It is
// initialized from kHTTPRestApiHandlerPathRegex, which is declared outside
// this view (presumably in http_rest_api_util.h — confirm).
54 const char*
const TFRTHttpRestApiHandler::kPathRegex =
55 kHTTPRestApiHandlerPathRegex;
// Constructor. The one initializer visible here converts the millisecond
// timeout argument into an absl::Duration and stores it in timeout_.
// NOTE(review): the remaining parameters, the other member initializers and
// the constructor body are not visible in this view — confirm against the
// full file before relying on this description.
57 TFRTHttpRestApiHandler::TFRTHttpRestApiHandler(
int timeout_in_ms,
60 timeout_(absl::Milliseconds(timeout_in_ms)),
// Destructor with an empty body; no explicit cleanup is performed here.
63 TFRTHttpRestApiHandler::~TFRTHttpRestApiHandler() {}
// Entry point for one REST call: parses the request path, then dispatches on
// the HTTP method and the parsed API method / sub-resource.
//
// Parameters:
//   http_method   - HTTP verb; only "POST" and "GET" are dispatched below.
//   request_path  - URL path handed to ParseModelInfo().
//   request_body  - request payload forwarded to the POST handlers.
//   headers       - not referenced in the lines visible here.
//   model_name    - out: filled by ParseModelInfo().
//   method        - out: filled by ParseModelInfo(); compared against
//                   "classify", "regress" and "predict".
//   output        - out: receives the handler's JSON (or the JSON form of the
//                   status, see the end of the function).
//
// NOTE(review): several lines of this function are missing from this view
// (closing braces, `else` lines, some `status =` assignments and the final
// return). Comments below describe only what is visible.
65 Status TFRTHttpRestApiHandler::ProcessRequest(
66 const absl::string_view http_method,
const absl::string_view request_path,
67 const absl::string_view request_body,
68 std::vector<std::pair<std::string, std::string>>* headers,
69 std::string* model_name, std::string* method, std::string* output) {
74 std::string model_subresource;
// Default result: stays InvalidArgument("Malformed request: ...") unless one
// of the branches below overwrites it.
75 Status status = errors::InvalidArgument(
"Malformed request: ", http_method,
77 absl::optional<int64_t> model_version;
78 absl::optional<std::string> model_version_label;
79 bool parse_successful;
// A non-OK status from ParseModelInfo() returns immediately;
// parse_successful is a separate flag checked by the dispatch below.
81 TF_RETURN_IF_ERROR(ParseModelInfo(
82 http_method, request_path, model_name, &model_version,
83 &model_version_label, method, &model_subresource, &parse_successful));
// Work on a copy of the member run options so the deadline is set for this
// request only: now + timeout_.
85 auto run_options = run_options_;
86 run_options.deadline = absl::Now() + timeout_;
// POST: inference calls, selected by *method.
89 if (http_method ==
"POST" && parse_successful) {
90 if (*method ==
"classify") {
91 status = ProcessClassifyRequest(*model_name, model_version,
92 model_version_label, request_body,
94 }
else if (*method ==
"regress") {
// NOTE(review): the line preceding this call is missing from this view;
// presumably it is `status =` as in the classify branch — confirm.
96 ProcessRegressRequest(*model_name, model_version, model_version_label,
97 request_body, run_options, output);
98 }
else if (*method ==
"predict") {
// NOTE(review): same as above — the preceding line is not visible here.
100 ProcessPredictRequest(*model_name, model_version, model_version_label,
101 request_body, run_options, output);
103 }
// GET: model metadata when the sub-resource is "metadata"; the model status
// call follows (presumably in an `else` branch whose line is missing here).
else if (http_method ==
"GET" && parse_successful) {
104 if (!model_subresource.empty() && model_subresource ==
"metadata") {
105 status = ProcessModelMetadataRequest(*model_name, model_version,
106 model_version_label, output);
108 status = ProcessModelStatusRequest(
109 *model_name, model_version, model_version_label, run_options, output);
// Writes the JSON representation of `status` into *output.
// NOTE(review): the lines just before this call are missing; it is probably
// guarded by a condition (e.g. only for non-OK status) — confirm.
113 MakeJsonFromStatus(status, output);
// Handles a classify call: builds a ClassificationRequest from the model
// identifiers and the JSON body, runs Servable::Classify(), and converts the
// classification result to JSON in *output.
//
// Each TF_RETURN_IF_ERROR step returns its non-OK status to the caller.
// NOTE(review): the lines wrapping the GetServableHandle() and
// MakeJsonFromClassificationResult() calls, and the function's final return,
// are missing from this view.
117 Status TFRTHttpRestApiHandler::ProcessClassifyRequest(
118 const absl::string_view model_name,
119 const absl::optional<int64_t>& model_version,
120 const absl::optional<absl::string_view>& model_version_label,
121 const absl::string_view request_body,
122 const Servable::RunOptions& run_options, std::string* output) {
// Request and response protos are both created on this function-local arena.
123 ::google::protobuf::Arena arena;
125 auto* request = ::google::protobuf::Arena::Create<ClassificationRequest>(&arena);
126 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
127 model_name, model_version, model_version_label,
128 request->mutable_model_spec()));
// Populate the rest of the request from the JSON body.
129 TF_RETURN_IF_ERROR(FillClassificationRequestFromJson(request_body, request));
131 auto* response = ::google::protobuf::Arena::Create<ClassificationResponse>(&arena);
// Look up the servable for the model spec filled in above.
132 ServableHandle<Servable> servable;
134 core_->GetServableHandle(request->model_spec(), &servable));
135 TF_RETURN_IF_ERROR(servable->Classify(run_options, *request, response));
137 MakeJsonFromClassificationResult(response->result(), output));
// Handles a regress call: builds a RegressionRequest from the model
// identifiers and the JSON body, runs Servable::Regress(), and returns the
// status of converting the regression result to JSON in *output.
//
// Each TF_RETURN_IF_ERROR step returns its non-OK status to the caller.
// NOTE(review): the line wrapping the GetServableHandle() call and the
// closing brace are missing from this view.
141 Status TFRTHttpRestApiHandler::ProcessRegressRequest(
142 const absl::string_view model_name,
143 const absl::optional<int64_t>& model_version,
144 const absl::optional<absl::string_view>& model_version_label,
145 const absl::string_view request_body,
146 const Servable::RunOptions& run_options, std::string* output) {
// Request and response protos are both created on this function-local arena.
147 ::google::protobuf::Arena arena;
149 auto* request = ::google::protobuf::Arena::Create<RegressionRequest>(&arena);
150 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
151 model_name, model_version, model_version_label,
152 request->mutable_model_spec()));
// Populate the rest of the request from the JSON body.
153 TF_RETURN_IF_ERROR(FillRegressionRequestFromJson(request_body, request));
155 auto* response = ::google::protobuf::Arena::Create<RegressionResponse>(&arena);
// Look up the servable for the model spec filled in above.
156 ServableHandle<Servable> servable;
158 core_->GetServableHandle(request->model_spec(), &servable));
159 TF_RETURN_IF_ERROR(servable->Regress(run_options, *request, response));
160 return MakeJsonFromRegressionResult(response->result(), output);
// Handles a predict call: builds a PredictRequest from the model identifiers
// and the JSON body, runs Servable::Predict(), and serializes the output
// tensors to JSON in *output.
//
// Each TF_RETURN_IF_ERROR step returns its non-OK status to the caller.
// NOTE(review): part of the FillPredictRequestFromJson() argument list, the
// line wrapping GetServableHandle(), and the final return are missing from
// this view.
163 Status TFRTHttpRestApiHandler::ProcessPredictRequest(
164 const absl::string_view model_name,
165 const absl::optional<int64_t>& model_version,
166 const absl::optional<absl::string_view>& model_version_label,
167 const absl::string_view request_body,
168 const Servable::RunOptions& run_options, std::string* output) {
// Request and response protos are both created on this function-local arena.
169 ::google::protobuf::Arena arena;
171 auto* request = ::google::protobuf::Arena::Create<PredictRequest>(&arena);
172 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
173 model_name, model_version, model_version_label,
174 request->mutable_model_spec()));
// `format` is passed to MakeJsonFromTensors() below; presumably it is set by
// FillPredictRequestFromJson() (that argument line is not visible — confirm).
176 JsonPredictRequestFormat format;
// The callback resolves a signature name to its input TensorInfo map via
// GetInfoMap(), using the model spec filled in above.
177 TF_RETURN_IF_ERROR(FillPredictRequestFromJson(
179 [
this, request](
const std::string& sig,
180 ::google::protobuf::Map<std::string, TensorInfo>* map) {
181 return this->GetInfoMap(request->model_spec(), sig, map);
185 auto* response = ::google::protobuf::Arena::Create<PredictResponse>(&arena);
// Look up the servable for the model spec filled in above.
187 ServableHandle<Servable> servable;
189 core_->GetServableHandle(request->model_spec(), &servable));
190 TF_RETURN_IF_ERROR(servable->Predict(run_options, *request, response));
192 TF_RETURN_IF_ERROR(MakeJsonFromTensors(response->outputs(), format, output));
// Handles a model status query: builds a GetModelStatusRequest, calls
// GetModelStatusImpl::GetModelStatus() and returns the status of serializing
// the response to JSON in *output.
//
// Returns InvalidArgument when model_name is empty.
// `run_options` is not referenced in the lines visible here.
// NOTE(review): the closing brace of the name check and the line wrapping the
// GetModelStatus() call are missing from this view.
196 Status TFRTHttpRestApiHandler::ProcessModelStatusRequest(
197 const absl::string_view model_name,
198 const absl::optional<int64_t>& model_version,
199 const absl::optional<absl::string_view>& model_version_label,
200 const Servable::RunOptions& run_options, std::string* output) {
// Reject requests that do not name a model.
203 if (model_name.empty()) {
204 return errors::InvalidArgument(
"Missing model name in request.");
// Request and response protos are both created on this function-local arena.
207 ::google::protobuf::Arena arena;
209 auto* request = ::google::protobuf::Arena::Create<GetModelStatusRequest>(&arena);
210 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
211 model_name, model_version, model_version_label,
212 request->mutable_model_spec()));
214 auto* response = ::google::protobuf::Arena::Create<GetModelStatusResponse>(&arena);
216 GetModelStatusImpl::GetModelStatus(core_, *request, response));
217 return ToJsonString(*response, output);
// Handles a model metadata query: builds a GetModelMetadataRequest asking for
// the kSignatureDef metadata field, calls
// TFRTGetModelMetadataImpl::GetModelMetadata() and returns the status of
// serializing the response to JSON in *output.
//
// Returns InvalidArgument when model_name is empty.
// NOTE(review): the closing brace of the name check and the line wrapping the
// GetModelMetadata() call are missing from this view.
220 Status TFRTHttpRestApiHandler::ProcessModelMetadataRequest(
221 const absl::string_view model_name,
222 const absl::optional<int64_t>& model_version,
223 const absl::optional<absl::string_view>& model_version_label,
224 std::string* output) {
// Reject requests that do not name a model.
225 if (model_name.empty()) {
226 return errors::InvalidArgument(
"Missing model name in request.");
// Request and response protos are both created on this function-local arena.
229 ::google::protobuf::Arena arena;
231 auto* request = ::google::protobuf::Arena::Create<GetModelMetadataRequest>(&arena);
// kSignatureDef is the only metadata field added to the request here.
233 request->add_metadata_field(std::string(kSignatureDef));
234 TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
235 model_name, model_version, model_version_label,
236 request->mutable_model_spec()));
238 auto* response = ::google::protobuf::Arena::Create<GetModelMetadataResponse>(&arena);
240 TFRTGetModelMetadataImpl::GetModelMetadata(core_, *request, response));
242 return ToJsonString(*response, output);
245 Status TFRTHttpRestApiHandler::GetInfoMap(
246 const ModelSpec& model_spec,
const std::string& signature_name,
247 ::google::protobuf::Map<std::string, tensorflow::TensorInfo>* infomap) {
248 ServableHandle<Servable> servable;
249 TF_RETURN_IF_ERROR(core_->GetServableHandle(model_spec, &servable));
251 down_cast<TfrtSavedModelServable*>(servable.get())->saved_model();
252 const std::string& signame =
253 signature_name.empty() ? kDefaultServingSignatureDefKey : signature_name;
254 auto iter = saved_model.GetMetaGraphDef().signature_def().find(signame);
255 if (iter == saved_model.GetMetaGraphDef().signature_def().end()) {
256 return errors::InvalidArgument(
"Serving signature name: \"", signame,
257 "\" not found in signature def");
259 *infomap = iter->second.inputs();