TensorFlow Serving C++ API Documentation
tfrt_prediction_service_impl.cc
/* Copyright 2022 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_serving/model_servers/tfrt_prediction_service_impl.h"

#include "grpc/grpc.h"
#include "grpcpp/server_context.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tsl/platform/errors.h"
#include "tensorflow_serving/apis/inference.pb.h"
#include "tensorflow_serving/apis/model.pb.h"
#include "tensorflow_serving/core/servable_handle.h"
#include "tensorflow_serving/model_servers/grpc_status_util.h"
#include "tensorflow_serving/model_servers/prediction_service_util.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_get_model_metadata_impl.h"
#include "tensorflow_serving/servables/tensorflow/util.h"

namespace tensorflow {
namespace serving {

// Computes the absolute deadline for a servable run from the gRPC request
// deadline. Returns InfiniteFuture() when timeout enforcement is disabled.
absl::Time TfrtPredictionServiceImpl::GetRequestDeadline(
    ::grpc::ServerContext *context) const {
  if (enforce_session_run_timeout_) {
    return absl::Now() +
           absl::Milliseconds(DeadlineToTimeoutMillis(context->raw_deadline()));
  }
  return absl::InfiniteFuture();
}

// Looks up the requested servable, runs Predict on it with the per-request
// deadline, and records latency and request-count metrics.
::grpc::Status TfrtPredictionServiceImpl::Predict(
    ::grpc::ServerContext *context, const PredictRequest *request,
    PredictResponse *response) {
  const uint64_t start = Env::Default()->NowMicros();

  Servable::RunOptions run_options;
  run_options.deadline = GetRequestDeadline(context);
  ServableHandle<Servable> servable;
  auto tf_status = core_->GetServableHandle(request->model_spec(), &servable);
  if (!tf_status.ok()) {
    VLOG(1) << "TFRT Predict get servable handle failed: "
            << tf_status.message();
    return ToGRPCStatus(tf_status);
  }

  tf_status = servable->Predict(run_options, *request, response);

  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), /*api=*/"Predict",
                         /*entrypoint=*/"GRPC",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "TFRT Predict failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);

  return status;
}

// Delegates to TFRTGetModelMetadataImpl and converts the result to a
// grpc::Status.
::grpc::Status TfrtPredictionServiceImpl::GetModelMetadata(
    ::grpc::ServerContext *context, const GetModelMetadataRequest *request,
    GetModelMetadataResponse *response) {
  const ::tensorflow::Status tf_status =
      TFRTGetModelMetadataImpl::GetModelMetadata(core_, *request, response);
  const ::grpc::Status status = ToGRPCStatus(tf_status);
  if (!status.ok()) {
    VLOG(1) << "TFRT GetModelMetadata failed: " << status.error_message();
  }
  return status;
}

// Looks up the requested servable, runs Classify on it with the per-request
// deadline, and records latency and request-count metrics.
::grpc::Status TfrtPredictionServiceImpl::Classify(
    ::grpc::ServerContext *context, const ClassificationRequest *request,
    ClassificationResponse *response) {
  const uint64_t start = Env::Default()->NowMicros();

  Servable::RunOptions run_options;
  run_options.deadline = GetRequestDeadline(context);
  ServableHandle<Servable> servable;
  auto tf_status = core_->GetServableHandle(request->model_spec(), &servable);
  if (!tf_status.ok()) {
    VLOG(1) << "TFRT Classify get servable handle failed: "
            << tf_status.message();
    return ToGRPCStatus(tf_status);
  }
  tf_status = servable->Classify(run_options, *request, response);

  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), /*api=*/"Classify",
                         /*entrypoint=*/"GRPC",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "TFRT Classify request failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);

  return status;
}

// Looks up the requested servable, runs Regress on it with the per-request
// deadline, and records latency and request-count metrics.
::grpc::Status TfrtPredictionServiceImpl::Regress(
    ::grpc::ServerContext *context, const RegressionRequest *request,
    RegressionResponse *response) {
  const uint64_t start = Env::Default()->NowMicros();

  Servable::RunOptions run_options;
  run_options.deadline = GetRequestDeadline(context);
  ServableHandle<Servable> servable;
  auto tf_status = core_->GetServableHandle(request->model_spec(), &servable);
  if (!tf_status.ok()) {
    VLOG(1) << "TFRT Regress get servable handle failed: "
            << tf_status.message();
    return ToGRPCStatus(tf_status);
  }

  tf_status = servable->Regress(run_options, *request, response);

  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), /*api=*/"Regress",
                         /*entrypoint=*/"GRPC",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "TFRT Regress request failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);

  return status;
}

namespace {

// Returns the ModelSpec of the first inference task in the request, or the
// default instance if no task carries a model spec.
const ModelSpec &GetModelSpecFromRequest(const MultiInferenceRequest &request) {
  if (request.tasks_size() > 0 && request.tasks(0).has_model_spec()) {
    return request.tasks(0).model_spec();
  }
  return ModelSpec::default_instance();
}

}  // namespace

// Resolves the servable from the first task's model spec and runs all
// inference tasks in the request against it.
::grpc::Status TfrtPredictionServiceImpl::MultiInference(
    ::grpc::ServerContext *context, const MultiInferenceRequest *request,
    MultiInferenceResponse *response) {
  Servable::RunOptions run_options;
  run_options.deadline = GetRequestDeadline(context);
  ServableHandle<Servable> servable;

  auto tf_status =
      core_->GetServableHandle(GetModelSpecFromRequest(*request), &servable);
  if (!tf_status.ok()) {
    VLOG(1) << "TFRT MultiInference get servable handle failed: "
            << tf_status.message();
    return ToGRPCStatus(tf_status);
  }

  tf_status = servable->MultiInference(run_options, *request, response);

  const ::grpc::Status status = ToGRPCStatus(tf_status);
  if (!status.ok()) {
    VLOG(1) << "TFRT MultiInference request failed: " << status.error_message();
  }
  return status;
}

}  // namespace serving
}  // namespace tensorflow
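
For reference, the sketch below shows how a client might exercise the Predict RPC implemented by this class. It is a minimal, illustrative example and not part of the file above: the server address (port 8500 is only the conventional TensorFlow Serving gRPC port), the model name "my_model", the signature name, and the input tensor name "x" are all assumptions. The client-side gRPC deadline is what GetRequestDeadline() converts into the servable run deadline when session-run timeouts are enforced.

#include <chrono>
#include <iostream>
#include <memory>

#include "grpcpp/grpcpp.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"

int main() {
  // Connect to a locally running model server (address is an assumption).
  auto channel = grpc::CreateChannel("localhost:8500",
                                     grpc::InsecureChannelCredentials());
  auto stub = tensorflow::serving::PredictionService::NewStub(channel);

  tensorflow::serving::PredictRequest request;
  request.mutable_model_spec()->set_name("my_model");  // hypothetical model
  request.mutable_model_spec()->set_signature_name("serving_default");

  // Hypothetical single float input named "x".
  tensorflow::Tensor input(tensorflow::DT_FLOAT, tensorflow::TensorShape({1}));
  input.flat<float>()(0) = 1.0f;
  input.AsProtoTensorContent(&(*request.mutable_inputs())["x"]);

  tensorflow::serving::PredictResponse response;
  grpc::ClientContext context;
  // The server derives Servable::RunOptions::deadline from this gRPC deadline
  // via GetRequestDeadline() when --enable_session_run_timeout is in effect.
  context.set_deadline(std::chrono::system_clock::now() +
                       std::chrono::seconds(5));

  grpc::Status status = stub->Predict(&context, request, &response);
  if (!status.ok()) {
    std::cerr << "Predict RPC failed: " << status.error_message() << std::endl;
    return 1;
  }
  std::cout << "Received " << response.outputs_size() << " output tensors."
            << std::endl;
  return 0;
}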