#include "tensorflow_serving/model_servers/tfrt_prediction_service_impl.h"

#include "grpc/grpc.h"
#include "grpcpp/server_context.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tsl/platform/errors.h"
#include "tensorflow_serving/apis/inference.pb.h"
#include "tensorflow_serving/apis/model.pb.h"
#include "tensorflow_serving/core/servable_handle.h"
#include "tensorflow_serving/model_servers/grpc_status_util.h"
#include "tensorflow_serving/model_servers/prediction_service_util.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_get_model_metadata_impl.h"
#include "tensorflow_serving/servables/tensorflow/util.h"

namespace tensorflow {
namespace serving {
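
// Converts the per-request gRPC deadline into an absolute absl::Time for the
// TFRT run options. Returns an infinite deadline when timeout enforcement is
// disabled.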
absl::Time TfrtPredictionServiceImpl::GetRequestDeadline(
    ::grpc::ServerContext* context) const {
  if (enforce_session_run_timeout_) {
    return absl::Now() +
           absl::Milliseconds(DeadlineToTimeoutMillis(context->raw_deadline()));
  }
  return absl::InfiniteFuture();
}
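
// Looks up the servable for the requested model, runs Predict on it, and
// records latency and per-model request-count metrics.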
::grpc::Status TfrtPredictionServiceImpl::Predict(
    ::grpc::ServerContext* context, const PredictRequest* request,
    PredictResponse* response) {
  const uint64_t start = Env::Default()->NowMicros();

  Servable::RunOptions run_options;
  run_options.deadline = GetRequestDeadline(context);
  ServableHandle<Servable> servable;
  tensorflow::Status tf_status =
      core_->GetServableHandle(request->model_spec(), &servable);
  if (!tf_status.ok()) {
    VLOG(1) << "TFRT Predict get servable handle failed: "
            << tf_status.message();
    return ToGRPCStatus(tf_status);
  }

  tf_status = servable->Predict(run_options, *request, response);

  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), /*api=*/"Predict",
                         /*entrypoint=*/"GRPC",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "TFRT Predict failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);
  return status;
}
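
// Retrieves signature metadata for the requested model via the TFRT
// metadata implementation; no servable run is involved.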
::grpc::Status TfrtPredictionServiceImpl::GetModelMetadata(
    ::grpc::ServerContext* context, const GetModelMetadataRequest* request,
    GetModelMetadataResponse* response) {
  const ::tensorflow::Status tf_status =
      TFRTGetModelMetadataImpl::GetModelMetadata(core_, *request, response);
  const ::grpc::Status status = ToGRPCStatus(tf_status);
  if (!status.ok()) {
    VLOG(1) << "TFRT GetModelMetadata failed: " << status.error_message();
  }
  return status;
}
::grpc::Status TfrtPredictionServiceImpl::Classify(
    ::grpc::ServerContext* context, const ClassificationRequest* request,
    ClassificationResponse* response) {
  const uint64_t start = Env::Default()->NowMicros();

  Servable::RunOptions run_options;
  run_options.deadline = GetRequestDeadline(context);
  ServableHandle<Servable> servable;
  tensorflow::Status tf_status =
      core_->GetServableHandle(request->model_spec(), &servable);
  if (!tf_status.ok()) {
    VLOG(1) << "TFRT Classify get servable handle failed: "
            << tf_status.message();
    return ToGRPCStatus(tf_status);
  }

  tf_status = servable->Classify(run_options, *request, response);

  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), /*api=*/"Classify",
                         /*entrypoint=*/"GRPC",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "TFRT Classify request failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);
  return status;
}
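
// Regression follows the same pattern as Predict and Classify.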
::grpc::Status TfrtPredictionServiceImpl::Regress(
    ::grpc::ServerContext* context, const RegressionRequest* request,
    RegressionResponse* response) {
  const uint64_t start = Env::Default()->NowMicros();

  Servable::RunOptions run_options;
  run_options.deadline = GetRequestDeadline(context);
  ServableHandle<Servable> servable;
  tensorflow::Status tf_status =
      core_->GetServableHandle(request->model_spec(), &servable);
  if (!tf_status.ok()) {
    VLOG(1) << "TFRT Regress get servable handle failed: "
            << tf_status.message();
    return ToGRPCStatus(tf_status);
  }

  tf_status = servable->Regress(run_options, *request, response);

  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), /*api=*/"Regress",
                         /*entrypoint=*/"GRPC",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "TFRT Regress request failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);
  return status;
}
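
// Resolves the servable from the first task's model spec, falling back to the
// default ModelSpec when the request carries no tasks.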
namespace {

const ModelSpec& GetModelSpecFromRequest(const MultiInferenceRequest& request) {
  if (request.tasks_size() > 0 && request.tasks(0).has_model_spec()) {
    return request.tasks(0).model_spec();
  }
  return ModelSpec::default_instance();
}

}  // namespace
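
// Runs all inference tasks in the request against a single servable. Unlike
// Predict/Classify/Regress, no latency or request-count metrics are recorded
// here.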
::grpc::Status TfrtPredictionServiceImpl::MultiInference(
    ::grpc::ServerContext* context, const MultiInferenceRequest* request,
    MultiInferenceResponse* response) {
  Servable::RunOptions run_options;
  run_options.deadline = GetRequestDeadline(context);
  ServableHandle<Servable> servable;
  tensorflow::Status tf_status =
      core_->GetServableHandle(GetModelSpecFromRequest(*request), &servable);
  if (!tf_status.ok()) {
    VLOG(1) << "TFRT MultiInference get model spec from request failed: "
            << tf_status.message();
    return ToGRPCStatus(tf_status);
  }

  tf_status = servable->MultiInference(run_options, *request, response);

  const ::grpc::Status status = ToGRPCStatus(tf_status);
  if (!status.ok()) {
    VLOG(1) << "TFRT MultiInference request failed: "
            << status.error_message();
  }
  return status;
}

}  // namespace serving
}  // namespace tensorflow