15 #ifndef TENSORFLOW_SERVING_EXPERIMENTAL_TENSORFLOW_OPS_REMOTE_PREDICT_KERNELS_REMOTE_PREDICT_OP_KERNEL_H_
16 #define TENSORFLOW_SERVING_EXPERIMENTAL_TENSORFLOW_OPS_REMOTE_PREDICT_KERNELS_REMOTE_PREDICT_OP_KERNEL_H_
18 #include "google/protobuf/wrappers.pb.h"
19 #include "google/protobuf/map.h"
20 #include "absl/status/status.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/core/framework/common_shape_fns.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/resource_mgr.h"
26 #include "tensorflow/core/framework/shape_inference.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/kernels/ops_util.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/lib/gtl/cleanup.h"
32 #include "tensorflow/core/platform/status.h"
33 #include "tensorflow/core/protobuf/named_tensor.pb.h"
34 #include "tensorflow_serving/apis/model.pb.h"
35 #include "tensorflow_serving/apis/predict.pb.h"
37 namespace tensorflow {
40 typedef google::protobuf::Map<tensorflow::string, tensorflow::TensorProto> AliasTensorMap;
// Asynchronous kernel (it initializes an AsyncOpKernel base, see below) that
// forwards its input tensors to a remote model server as a PredictRequest and
// exposes the tensors of the PredictResponse as op outputs.
//
// PredictionServiceStubType is the RPC client. From its uses in this class it
// must provide:
//   - static Create(target_address, std::unique_ptr<PredictionServiceStubType>*)
//     returning absl::Status;
//   - CreateRpc(absl::Duration) returning a StatusOr-like value
//     (ok(), status(), value());
//   - Predict(rpc, PredictRequest*, PredictResponse*, callback), where the
//     callback is invoked with a const absl::Status&.
// NOTE(review): the class declaration line and the constructor signature are
// not present in this listing; confirm against the original source.
44 template <
typename PredictionServiceStubType>
// Constructor: reads the op attributes below (any GetAttr failure fails
// kernel construction via OP_REQUIRES_OK) and creates the RPC stub for
// `target_address`.
48 : AsyncOpKernel(context) {
49 string target_address;
50 OP_REQUIRES_OK(context,
51 context->GetAttr(
"target_address", &target_address));
52 OP_REQUIRES_OK(context, context->GetAttr(
"model_name", &model_name_));
53 OP_REQUIRES_OK(context, context->GetAttr(
"model_version", &model_version_));
54 OP_REQUIRES_OK(context, context->GetAttr(
"max_rpc_deadline_millis",
55 &max_rpc_deadline_millis_));
56 OP_REQUIRES_OK(context, context->GetAttr(
"fail_op_on_rpc_error",
57 &fail_op_on_rpc_error_));
58 OP_REQUIRES_OK(context,
59 context->GetAttr(
"signature_name", &signature_name_));
// Create the stub. Create() reports an absl::Status; it is converted into a
// tensorflow::Status with the same numeric code and message so that
// OP_REQUIRES fails construction with the original error.
60 absl::Status prediction_service_status =
61 PredictionServiceStubType::Create(target_address, &prediction_service_);
62 OP_REQUIRES(context, prediction_service_status.ok(),
63 tensorflow::Status(
static_cast<tensorflow::errors::Code
>(
64 prediction_service_status.code()),
65 prediction_service_status.message()));
// Builds a PredictRequest from the op inputs and issues it through
// prediction_service_. Output population and the call to `done` happen later,
// in the RPC callback, via PostProcessResponse.
//
// Input layout, as read below:
//   input 0                         : string aliases for the input tensors
//   input list "input_tensors"      : the tensors to send
//   input 1 + size(input_tensors)   : string aliases of the outputs to request
68 void ComputeAsync(OpKernelContext* context, DoneCallback done)
override {
70 const auto& input_tensor_aliases = context->input(0).flat<tstring>();
73 OpInputList input_tensors;
75 context, context->input_list(
"input_tensors", &input_tensors), done);
// The output aliases follow the variadic input list, so their input index
// depends on how many input tensors were passed.
79 auto output_tensor_aliases =
80 context->input(1 + input_tensors.size()).flat<tstring>();
// `request` (and `response` below) are heap-allocated: they are handed to
// Predict and captured by the RPC callback, so they must stay alive until
// that callback runs.
83 PredictRequest* request =
new PredictRequest();
// Model spec: name and signature are always set; a version is set only when
// model_version_ is non-negative.
85 request->mutable_model_spec()->set_name(model_name_);
87 request->mutable_model_spec()->set_signature_name(signature_name_);
89 if (model_version_ >= 0) {
90 request->mutable_model_spec()->mutable_version()->set_value(
94 AliasTensorMap& inputs = *request->mutable_inputs();
// Serialize each input tensor with AsProtoField and store it in the request
// under its alias.
// NOTE(review): input_tensors[i] is indexed once per alias, but nothing
// visible here checks input_tensors.size() == input_tensor_aliases.size();
// a mismatch would index past the input list. Consider validating first.
95 for (
int i = 0; i < input_tensor_aliases.size(); ++i) {
96 tensorflow::TensorProto proto;
97 input_tensors[i].AsProtoField(&proto);
98 inputs[input_tensor_aliases(i)] = proto;
// Ask the server only for the outputs named by the output aliases.
101 for (
int i = 0; i < output_tensor_aliases.size(); ++i) {
102 request->add_output_filter(tensorflow::string(output_tensor_aliases(i)));
105 PredictResponse* response =
new PredictResponse();
// Create the RPC handle with a deadline of max_rpc_deadline_millis_ ms.
107 auto rpc_or = prediction_service_->CreateRpc(
108 absl::Milliseconds(max_rpc_deadline_millis_));
// On CreateRpc failure the op fails with the converted status.
// NOTE(review): the failure callback passed to OP_REQUIRES_ASYNC is not
// visible in this listing; confirm it deletes `request` and `response` and
// calls `done`, otherwise both leak on this path.
109 OP_REQUIRES_ASYNC(context, rpc_or.ok(),
110 tensorflow::Status(
static_cast<tensorflow::errors::Code
>(
111 rpc_or.status().code()),
112 rpc_or.status().message()),
118 auto rpc = rpc_or.value();
// The callback captures raw pointers (rpc, request, response) and forwards
// the RPC status to PostProcessResponse together with a completion closure.
// NOTE(review): that closure's body is not visible here; it is presumably
// where rpc, request and response are freed and `done` is called — confirm.
119 auto callback = [
this, context, rpc, request, response,
120 output_tensor_aliases, done](
const absl::Status& status) {
121 PostProcessResponse(context, response, status, fail_op_on_rpc_error_,
122 output_tensor_aliases, [&]() {
130 prediction_service_->Predict(rpc, request, response, callback);
// Converts the RPC result into op outputs, then runs `rpc_done` once. It runs
// either when rpc_cleaner goes out of scope, or as the failure callback
// handed to an OP_REQUIRES_*_ASYNC macro.
//
// Outputs written below:
//   output 0         : scalar int, the RPC status code
//   output 1         : scalar string, the RPC status message
//   "output_tensors" : one tensor per output alias, decoded from `response`
//
// When the RPC failed and fail_op_on_rpc_error is true, the op fails with the
// RPC status. Otherwise the loop below allocates a shape-{} placeholder for
// every entry of "output_tensors".
133 void PostProcessResponse(OpKernelContext* context, PredictResponse* response,
134 const absl::Status& rpc_status,
135 bool fail_op_on_rpc_error,
136 TTypes<const tstring>::Flat output_tensor_aliases,
137 DoneCallback rpc_done) {
// rpc_done is wrapped in a cleanup so it runs when this function returns.
// Each failure path instead release()s it and passes it to the macro, which
// calls it after recording the error.
138 auto rpc_cleaner = gtl::MakeCleanup([&] { rpc_done(); });
// Output 0: the RPC status code as a scalar int.
140 OP_REQUIRES_OK_ASYNC(
141 context, context->allocate_output(0, TensorShape({}), &status_code),
142 rpc_cleaner.release());
143 status_code->scalar<
int>()() =
static_cast<int>(rpc_status.code());
// Output 1: the RPC status message as a scalar string.
144 Tensor* status_error_message;
145 OP_REQUIRES_OK_ASYNC(
147 context->allocate_output(1, TensorShape({}), &status_error_message),
148 rpc_cleaner.release());
149 status_error_message->scalar<tstring>()() = rpc_status.message();
// The variadic "output_tensors" outputs.
150 OpOutputList output_tensors_list;
151 OP_REQUIRES_OK_ASYNC(
152 context, context->output_list(
"output_tensors", &output_tensors_list),
153 rpc_cleaner.release());
// RPC failure handling: either fail the op with the RPC status (converted to
// a tensorflow::Status with the same code and message), or allocate
// shape-{} placeholders for every output tensor.
// NOTE(review): some enclosing braces of this branch are not visible in this
// listing; confirm that the placeholder path returns before the decoding
// below.
155 if (!rpc_status.ok()) {
156 if (fail_op_on_rpc_error) {
157 OP_REQUIRES_OK_ASYNC(
160 static_cast<tensorflow::errors::Code
>(rpc_status.code()),
161 rpc_status.message()),
162 rpc_cleaner.release());
165 for (
int i = 0; i < output_tensors_list.size(); ++i) {
167 OP_REQUIRES_OK_ASYNC(
169 output_tensors_list.allocate(i, TensorShape({}), &unused),
170 rpc_cleaner.release());
176 context, output_tensors_list.size() == output_tensor_aliases.size(),
178 "Response doesn't have the right number of outputs; actual: ",
179 output_tensors_list.size(),
180 " expected: ", output_tensor_aliases.size()),
181 rpc_cleaner.release());
// The check above compares the op's declared "output_tensors" count with the
// number of output aliases. Despite its message, it does not inspect the
// response. Each alias is then looked up in the response outputs and decoded
// back into a Tensor.
// NOTE(review): protobuf Map::operator[] default-inserts a missing key, so an
// alias absent from the response presumably surfaces as the FromProto
// Internal error below rather than a not-found error — confirm.
182 AliasTensorMap& outputs = *response->mutable_outputs();
183 for (
int i = 0; i < output_tensor_aliases.size(); i++) {
184 Tensor output_tensor;
186 context, output_tensor.FromProto(outputs[output_tensor_aliases(i)]),
187 errors::Internal(
"Response tensor proto: ",
188 tensorflow::string(output_tensor_aliases(i)),
189 " cannot be converted back to a tensor."),
190 rpc_cleaner.release());
191 output_tensors_list.set(i, output_tensor);
// NOTE(review): `model_name_` is read by the constructor and by ComputeAsync,
// but its declaration is not present in this listing.
// Version placed in the model spec; a negative value means no version is set.
197 int64_t model_version_;
// If true, a failed RPC fails the op. Otherwise the op still emits the status
// outputs, plus shape-{} placeholders for "output_tensors".
198 bool fail_op_on_rpc_error_;
// Deadline, in milliseconds, used when creating each RPC.
199 int64_t max_rpc_deadline_millis_;
// Signature name placed in the model spec of every request.
200 string signature_name_;
// RPC client, created in the constructor from the target_address attribute.
201 std::unique_ptr<PredictionServiceStubType> prediction_service_;