TensorFlow Serving C++ API Documentation
prediction_service_impl.cc
/* Copyright 2018 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_serving/model_servers/prediction_service_impl.h"

#include "grpc/grpc.h"
#include "tensorflow_serving/model_servers/grpc_status_util.h"
#include "tensorflow_serving/servables/tensorflow/classification_service.h"
#include "tensorflow_serving/servables/tensorflow/get_model_metadata_impl.h"
#include "tensorflow_serving/servables/tensorflow/multi_inference_helper.h"
#include "tensorflow_serving/servables/tensorflow/regression_service.h"
#include "tensorflow_serving/servables/tensorflow/util.h"

namespace tensorflow {
namespace serving {

namespace {

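// Returns the thread pools provided by `thread_pool_factory`, or a
// default-constructed ScopedThreadPools when no factory is configured.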
ScopedThreadPools GetThreadPools(ThreadPoolFactory *thread_pool_factory) {
  return thread_pool_factory == nullptr ? ScopedThreadPools()
                                        : thread_pool_factory->GetThreadPools();
}

}  // namespace

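// Handles gRPC Predict calls: optionally converts the client's deadline into a
// Session::Run timeout, delegates to the configured predictor, and records
// per-model latency (on success) and request-count metrics.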
::grpc::Status PredictionServiceImpl::Predict(::grpc::ServerContext *context,
                                              const PredictRequest *request,
                                              PredictResponse *response) {
  const uint64_t start = Env::Default()->NowMicros();
  tensorflow::RunOptions run_options = tensorflow::RunOptions();
  if (enforce_session_run_timeout_) {
    run_options.set_timeout_in_ms(
        DeadlineToTimeoutMillis(context->raw_deadline()));
  }

  const ::tensorflow::Status tf_status =
      predictor_->Predict(run_options, core_, *request, response);
  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), /*api=*/"Predict",
                         /*entrypoint=*/"GRPC",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "Predict failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);

  return status;
}

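// Handles gRPC GetModelMetadata calls by delegating to
// GetModelMetadataImpl::GetModelMetadata and converting the result into a
// gRPC status.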
::grpc::Status PredictionServiceImpl::GetModelMetadata(
    ::grpc::ServerContext *context, const GetModelMetadataRequest *request,
    GetModelMetadataResponse *response) {
  const ::grpc::Status status = ToGRPCStatus(
      GetModelMetadataImpl::GetModelMetadata(core_, *request, response));
  if (!status.ok()) {
    VLOG(1) << "GetModelMetadata failed: " << status.error_message();
  }
  return status;
}

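// Handles gRPC Classify calls: applies the optional deadline-derived timeout,
// delegates to TensorflowClassificationServiceImpl::Classify using the
// configured thread pools, and records latency and request-count metrics.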
::grpc::Status PredictionServiceImpl::Classify(
    ::grpc::ServerContext *context, const ClassificationRequest *request,
    ClassificationResponse *response) {
  const uint64_t start = Env::Default()->NowMicros();
  tensorflow::RunOptions run_options = tensorflow::RunOptions();
  // By default the timeout is infinite, which is also the RunOptions default.
  if (enforce_session_run_timeout_) {
    run_options.set_timeout_in_ms(
        DeadlineToTimeoutMillis(context->raw_deadline()));
  }

  const ::tensorflow::Status tf_status =
      TensorflowClassificationServiceImpl::Classify(
          run_options, core_, GetThreadPools(thread_pool_factory_).get(),
          *request, response);
  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), /*api=*/"Classify",
                         /*entrypoint=*/"GRPC",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "Classify request failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);

  return status;
}

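// Handles gRPC Regress calls: same flow as Classify, but delegates to
// TensorflowRegressionServiceImpl::Regress.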
::grpc::Status PredictionServiceImpl::Regress(::grpc::ServerContext *context,
                                              const RegressionRequest *request,
                                              RegressionResponse *response) {
  const uint64_t start = Env::Default()->NowMicros();
  tensorflow::RunOptions run_options = tensorflow::RunOptions();
  // By default the timeout is infinite, which is also the RunOptions default.
  if (enforce_session_run_timeout_) {
    run_options.set_timeout_in_ms(
        DeadlineToTimeoutMillis(context->raw_deadline()));
  }

  const ::tensorflow::Status tf_status =
      TensorflowRegressionServiceImpl::Regress(
          run_options, core_, GetThreadPools(thread_pool_factory_).get(),
          *request, response);
  const ::grpc::Status status = ToGRPCStatus(tf_status);

  if (status.ok()) {
    RecordRequestLatency(request->model_spec().name(), /*api=*/"Regress",
                         /*entrypoint=*/"GRPC",
                         Env::Default()->NowMicros() - start);
  } else {
    VLOG(1) << "Regress request failed: " << status.error_message();
  }
  RecordModelRequestCount(request->model_spec().name(), tf_status);

  return status;
}

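// Handles gRPC MultiInference calls (multiple Classify/Regress tasks on a
// shared input) by delegating to RunMultiInferenceWithServerCore.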
::grpc::Status PredictionServiceImpl::MultiInference(
    ::grpc::ServerContext *context, const MultiInferenceRequest *request,
    MultiInferenceResponse *response) {
  tensorflow::RunOptions run_options = tensorflow::RunOptions();
  // By default the timeout is infinite, which is also the RunOptions default.
  if (enforce_session_run_timeout_) {
    run_options.set_timeout_in_ms(
        DeadlineToTimeoutMillis(context->raw_deadline()));
  }
  const ::grpc::Status status = ToGRPCStatus(RunMultiInferenceWithServerCore(
      run_options, core_, GetThreadPools(thread_pool_factory_).get(), *request,
      response));
  if (!status.ok()) {
    VLOG(1) << "MultiInference request failed: " << status.error_message();
  }
  return status;
}

}  // namespace serving
}  // namespace tensorflow
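
Client usage sketch (not part of the source file above). Because Predict converts the caller's gRPC deadline into a Session::Run timeout when enforce_session_run_timeout_ is enabled, clients typically set a deadline on the ClientContext. The snippet below is a minimal illustration, assuming the generated PredictionService stub from tensorflow_serving/apis/prediction_service.grpc.pb.h; the server address, port 8500, model name, and the gRPC include path are placeholder assumptions, not values taken from this file.

#include <chrono>
#include <iostream>
#include <memory>

#include "grpcpp/grpcpp.h"  // include path may differ by gRPC version
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"

int main() {
  // Placeholder address; 8500 is the conventional TensorFlow Serving gRPC port.
  auto channel = ::grpc::CreateChannel("localhost:8500",
                                       ::grpc::InsecureChannelCredentials());
  auto stub = tensorflow::serving::PredictionService::NewStub(channel);

  tensorflow::serving::PredictRequest request;
  request.mutable_model_spec()->set_name("my_model");  // placeholder model name
  // Populate (*request.mutable_inputs())["..."] with TensorProto inputs here.

  tensorflow::serving::PredictResponse response;
  ::grpc::ClientContext context;
  // With enforce_session_run_timeout_ enabled on the server, this deadline is
  // also applied to the underlying Session::Run via RunOptions.timeout_in_ms.
  context.set_deadline(std::chrono::system_clock::now() +
                       std::chrono::seconds(5));

  const ::grpc::Status status = stub->Predict(&context, request, &response);
  if (!status.ok()) {
    std::cerr << "Predict failed: " << status.error_message() << std::endl;
  }
  return status.ok() ? 0 : 1;
}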