TensorFlow Serving C++ API Documentation
remote_predict_op_kernel.h
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_SERVING_EXPERIMENTAL_TENSORFLOW_OPS_REMOTE_PREDICT_KERNELS_REMOTE_PREDICT_OP_KERNEL_H_
16 #define TENSORFLOW_SERVING_EXPERIMENTAL_TENSORFLOW_OPS_REMOTE_PREDICT_KERNELS_REMOTE_PREDICT_OP_KERNEL_H_
17 
18 #include "google/protobuf/wrappers.pb.h"
19 #include "google/protobuf/map.h"
20 #include "absl/status/status.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/core/framework/common_shape_fns.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/resource_mgr.h"
26 #include "tensorflow/core/framework/shape_inference.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/kernels/ops_util.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/lib/gtl/cleanup.h"
32 #include "tensorflow/core/platform/status.h"
33 #include "tensorflow/core/protobuf/named_tensor.pb.h"
34 #include "tensorflow_serving/apis/model.pb.h"
35 #include "tensorflow_serving/apis/predict.pb.h"
36 
37 namespace tensorflow {
38 namespace serving {
39 
40 typedef google::protobuf::Map<tensorflow::string, tensorflow::TensorProto> AliasTensorMap;
41 
// Remote Predict Op kernel implementation class templated on different
// PredictionServiceStubTypes.
//
// At construction the kernel reads its configuration from op attributes
// ("target_address", "model_name", "model_version", "max_rpc_deadline_millis",
// "fail_op_on_rpc_error", "signature_name") and creates a prediction-service
// stub. Each invocation then issues an asynchronous Predict RPC through the
// stub and copies the response tensors into the op's outputs.
template <typename PredictionServiceStubType>
class RemotePredictOp : public AsyncOpKernel {
 public:
  // Reads all op attributes and creates the prediction-service stub for
  // `target_address`. Kernel construction fails if any attribute is missing
  // or the stub cannot be created.
  explicit RemotePredictOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    string target_address;
    OP_REQUIRES_OK(context,
                   context->GetAttr("target_address", &target_address));
    OP_REQUIRES_OK(context, context->GetAttr("model_name", &model_name_));
    OP_REQUIRES_OK(context, context->GetAttr("model_version", &model_version_));
    OP_REQUIRES_OK(context, context->GetAttr("max_rpc_deadline_millis",
                                             &max_rpc_deadline_millis_));
    OP_REQUIRES_OK(context, context->GetAttr("fail_op_on_rpc_error",
                                             &fail_op_on_rpc_error_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("signature_name", &signature_name_));
    absl::Status prediction_service_status =
        PredictionServiceStubType::Create(target_address, &prediction_service_);
    // Convert the absl::Status from the stub factory into a
    // tensorflow::Status so OP_REQUIRES can report it.
    OP_REQUIRES(context, prediction_service_status.ok(),
                tensorflow::Status(static_cast<tensorflow::errors::Code>(
                                       prediction_service_status.code()),
                                   prediction_service_status.message()));
  }

  // Builds a PredictRequest from the op inputs, fires the RPC asynchronously,
  // and post-processes the response in the RPC completion callback. `done` is
  // invoked exactly once on every path (creation failure, RPC error, success).
  //
  // Input layout (established by the indexing below): input 0 holds the input
  // tensor aliases, the "input_tensors" list follows, and the input at index
  // 1 + input_tensors.size() holds the output tensor aliases.
  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // Get the input tensor alias names.
    const auto& input_tensor_aliases = context->input(0).flat<tstring>();

    // Get the input tensors.
    OpInputList input_tensors;
    OP_REQUIRES_OK_ASYNC(
        context, context->input_list("input_tensors", &input_tensors), done);
    // Get the output tensor alias names.
    // Directly index to output_tensor_aliases by moving past all the input
    // before it, including the input_tensor_aliases and input_tensors.
    auto output_tensor_aliases =
        context->input(1 + input_tensors.size()).flat<tstring>();

    // Build the PredictRequest. Heap-allocated because it must outlive this
    // method: the RPC completes asynchronously and the completion callback
    // below is responsible for deleting it.
    PredictRequest* request = new PredictRequest();

    request->mutable_model_spec()->set_name(model_name_);

    request->mutable_model_spec()->set_signature_name(signature_name_);

    // A negative model_version_ is treated as unset: the version field is
    // left out of the request.
    if (model_version_ >= 0) {
      request->mutable_model_spec()->mutable_version()->set_value(
          model_version_);
    }

    // Serialize each input tensor into the request under its alias name.
    AliasTensorMap& inputs = *request->mutable_inputs();
    for (int i = 0; i < input_tensor_aliases.size(); ++i) {
      tensorflow::TensorProto proto;
      input_tensors[i].AsProtoField(&proto);
      inputs[input_tensor_aliases(i)] = proto;
    }

    // Restrict the response to the requested output aliases.
    for (int i = 0; i < output_tensor_aliases.size(); ++i) {
      request->add_output_filter(tensorflow::string(output_tensor_aliases(i)));
    }

    // Heap-allocated for the same async-lifetime reason as `request`.
    PredictResponse* response = new PredictResponse();

    auto rpc_or = prediction_service_->CreateRpc(
        absl::Milliseconds(max_rpc_deadline_millis_));
    // On RPC-creation failure: free both protos and signal completion before
    // returning (the macro invokes the lambda only in its failure branch).
    OP_REQUIRES_ASYNC(context, rpc_or.ok(),
                      tensorflow::Status(static_cast<tensorflow::errors::Code>(
                                             rpc_or.status().code()),
                                         rpc_or.status().message()),
                      [&]() {
                        delete request;
                        delete response;
                        done();
                      });
    auto rpc = rpc_or.value();
    // The completion callback owns rpc/request/response: after the response
    // has been post-processed, the inner lambda deletes all three and invokes
    // `done` to complete the op.
    // NOTE(review): `output_tensor_aliases` is captured by value, but it is a
    // flat view into tensor data owned by `context`; this presumes the
    // framework keeps the op inputs alive until `done` runs — verify.
    auto callback = [this, context, rpc, request, response,
                     output_tensor_aliases, done](const absl::Status& status) {
      PostProcessResponse(context, response, status, fail_op_on_rpc_error_,
                          output_tensor_aliases, [&]() {
                            delete rpc;
                            delete request;
                            delete response;
                            done();
                          });
    };
    // Make the RPC call.
    prediction_service_->Predict(rpc, request, response, callback);
  }

  // Translates the RPC outcome into op outputs.
  //
  // Outputs 0 and 1 always receive the RPC status code and error message.
  // If the RPC failed and fail_op_on_rpc_error is true, the op itself fails;
  // if false, the op succeeds with empty scalar placeholders for each output
  // tensor. On RPC success, each response tensor named in
  // `output_tensor_aliases` is converted back to a Tensor and written to the
  // "output_tensors" output list. `rpc_done` runs exactly once on every exit
  // path (see the cleaner note below).
  void PostProcessResponse(OpKernelContext* context, PredictResponse* response,
                           const absl::Status& rpc_status,
                           bool fail_op_on_rpc_error,
                           TTypes<const tstring>::Flat output_tensor_aliases,
                           DoneCallback rpc_done) {
    // Ensures rpc_done() fires when this method returns normally. The
    // `rpc_cleaner.release()` expressions below are evaluated only inside the
    // OP_REQUIRES*_ASYNC failure branches: release() disarms the cleaner and
    // returns the stored callable, which the macro then invokes — so
    // rpc_done() runs exactly once whether a check fails or not.
    auto rpc_cleaner = gtl::MakeCleanup([&] { rpc_done(); });
    Tensor* status_code;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, TensorShape({}), &status_code),
        rpc_cleaner.release());
    status_code->scalar<int>()() = static_cast<int>(rpc_status.code());
    Tensor* status_error_message;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(1, TensorShape({}), &status_error_message),
        rpc_cleaner.release());
    status_error_message->scalar<tstring>()() = rpc_status.message();
    OpOutputList output_tensors_list;
    OP_REQUIRES_OK_ASYNC(
        context, context->output_list("output_tensors", &output_tensors_list),
        rpc_cleaner.release());
    // Process the response.
    if (!rpc_status.ok()) {
      if (fail_op_on_rpc_error) {
        // Propagate the RPC error as the op's status.
        OP_REQUIRES_OK_ASYNC(
            context,
            tensorflow::Status(
                static_cast<tensorflow::errors::Code>(rpc_status.code()),
                rpc_status.message()),
            rpc_cleaner.release());
      } else {
        // Allocate some empty output for the output_tensors.
        for (int i = 0; i < output_tensors_list.size(); ++i) {
          Tensor* unused;
          OP_REQUIRES_OK_ASYNC(
              context,
              output_tensors_list.allocate(i, TensorShape({}), &unused),
              rpc_cleaner.release());
        }
        return;
      }
    }
    // The op declares one output per requested alias; a mismatch means the
    // graph wiring and the alias input disagree.
    OP_REQUIRES_ASYNC(
        context, output_tensors_list.size() == output_tensor_aliases.size(),
        errors::Internal(
            "Response doesn't have the right number of outputs; actual: ",
            output_tensors_list.size(),
            " expected: ", output_tensor_aliases.size()),
        rpc_cleaner.release());
    // Deserialize each response tensor into the corresponding op output.
    AliasTensorMap& outputs = *response->mutable_outputs();
    for (int i = 0; i < output_tensor_aliases.size(); i++) {
      Tensor output_tensor;
      OP_REQUIRES_ASYNC(
          context, output_tensor.FromProto(outputs[output_tensor_aliases(i)]),
          errors::Internal("Response tensor proto: ",
                           tensorflow::string(output_tensor_aliases(i)),
                           " cannot be converted back to a tensor."),
          rpc_cleaner.release());
      output_tensors_list.set(i, output_tensor);
    }
  }

 private:
  // Name of the model to query on the remote server.
  string model_name_;
  // Requested model version; a negative value means "unspecified" (see
  // ComputeAsync).
  int64_t model_version_;
  // If true, an RPC failure fails the op; if false, the op succeeds and the
  // failure is reported only through the status-code/message outputs.
  bool fail_op_on_rpc_error_;
  // Deadline applied to each Predict RPC.
  int64_t max_rpc_deadline_millis_;
  // SavedModel signature to invoke.
  string signature_name_;
  // Stub used to create and issue Predict RPCs; created in the constructor.
  std::unique_ptr<PredictionServiceStubType> prediction_service_;
};
203 
204 } // namespace serving
205 } // namespace tensorflow
206 #endif // TENSORFLOW_SERVING_EXPERIMENTAL_TENSORFLOW_OPS_REMOTE_PREDICT_KERNELS_REMOTE_PREDICT_OP_KERNEL_H_