#include "tensorflow_serving/model_servers/prediction_service_impl.h"

#include "grpc/grpc.h"
#include "tensorflow_serving/model_servers/grpc_status_util.h"
#include "tensorflow_serving/servables/tensorflow/classification_service.h"
#include "tensorflow_serving/servables/tensorflow/get_model_metadata_impl.h"
#include "tensorflow_serving/servables/tensorflow/multi_inference_helper.h"
#include "tensorflow_serving/servables/tensorflow/regression_service.h"
#include "tensorflow_serving/servables/tensorflow/util.h"

namespace tensorflow {
namespace serving {

namespace {

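// Returns the thread pools supplied by the ThreadPoolFactory, or
// default-constructed (empty) pools when no factory is configured.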
ScopedThreadPools GetThreadPools(ThreadPoolFactory *thread_pool_factory) {
  return thread_pool_factory == nullptr
             ? ScopedThreadPools()
             : thread_pool_factory->GetThreadPools();
}

}  // namespace
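
// Handles a Predict RPC: runs the configured predictor under per-request
// RunOptions and records latency and per-model request-count metrics.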
::grpc::Status PredictionServiceImpl::Predict(::grpc::ServerContext *context,
                                              const PredictRequest *request,
                                              PredictResponse *response) {
  const uint64_t start = Env::Default()->NowMicros();
  tensorflow::RunOptions run_options = tensorflow::RunOptions();
  // Apply a per-request timeout derived from the gRPC deadline when
  // session-run timeouts are enforced.
  if (enforce_session_run_timeout_) {
    run_options.set_timeout_in_ms(
        DeadlineToTimeoutMillis(context->raw_deadline()));
  }

  const ::tensorflow::Status tf_status =
      predictor_->Predict(run_options, core_, *request, response);
  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), "Predict",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "Predict failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);

  return status;
}
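
// Handles a GetModelMetadata RPC by delegating to GetModelMetadataImpl.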
::grpc::Status PredictionServiceImpl::GetModelMetadata(
    ::grpc::ServerContext *context, const GetModelMetadataRequest *request,
    GetModelMetadataResponse *response) {
  const ::grpc::Status status = ToGRPCStatus(
      GetModelMetadataImpl::GetModelMetadata(core_, *request, response));
  if (!status.ok()) {
    VLOG(1) << "GetModelMetadata failed: " << status.error_message();
  }
  return status;
}
::grpc::Status PredictionServiceImpl::Classify(
    ::grpc::ServerContext *context, const ClassificationRequest *request,
    ClassificationResponse *response) {
  const uint64_t start = Env::Default()->NowMicros();
  tensorflow::RunOptions run_options = tensorflow::RunOptions();
  // Apply a per-request timeout derived from the gRPC deadline when
  // session-run timeouts are enforced.
  if (enforce_session_run_timeout_) {
    run_options.set_timeout_in_ms(
        DeadlineToTimeoutMillis(context->raw_deadline()));
  }

  const ::tensorflow::Status tf_status =
      TensorflowClassificationServiceImpl::Classify(
          run_options, core_, GetThreadPools(thread_pool_factory_).get(),
          *request, response);
  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), "Classify",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "Classify request failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);

  return status;
}
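
// Handles a Regress RPC via TensorflowRegressionServiceImpl; mirrors the
// Classify path.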
::grpc::Status PredictionServiceImpl::Regress(::grpc::ServerContext *context,
                                              const RegressionRequest *request,
                                              RegressionResponse *response) {
  const uint64_t start = Env::Default()->NowMicros();
  tensorflow::RunOptions run_options = tensorflow::RunOptions();
  // Apply a per-request timeout derived from the gRPC deadline when
  // session-run timeouts are enforced.
  if (enforce_session_run_timeout_) {
    run_options.set_timeout_in_ms(
        DeadlineToTimeoutMillis(context->raw_deadline()));
  }

  const ::tensorflow::Status tf_status =
      TensorflowRegressionServiceImpl::Regress(
          run_options, core_, GetThreadPools(thread_pool_factory_).get(),
          *request, response);
  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), "Regress",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "Regress request failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);

  return status;
}
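
// Handles a MultiInference RPC by delegating to
// RunMultiInferenceWithServerCore.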
::grpc::Status PredictionServiceImpl::MultiInference(
    ::grpc::ServerContext *context, const MultiInferenceRequest *request,
    MultiInferenceResponse *response) {
  tensorflow::RunOptions run_options = tensorflow::RunOptions();
  // Apply a per-request timeout derived from the gRPC deadline when
  // session-run timeouts are enforced.
  if (enforce_session_run_timeout_) {
    run_options.set_timeout_in_ms(
        DeadlineToTimeoutMillis(context->raw_deadline()));
  }
  const ::grpc::Status status = ToGRPCStatus(RunMultiInferenceWithServerCore(
      run_options, core_, GetThreadPools(thread_pool_factory_).get(), *request,
      response));
  if (!status.ok()) {
    VLOG(1) << "MultiInference request failed: " << status.error_message();
  }
  return status;
}

}  // namespace serving
}  // namespace tensorflow