TensorFlow Serving C++ API Documentation
remote_predict_op.cc
/* Copyright 2020 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

21 namespace tensorflow {
22 
23 REGISTER_OP("TfServingRemotePredict")
24  .Attr("T: list(type)")
25  .Attr("target_address: string = ''")
26  .Attr("model_name: string = ''")
27  .Attr("model_version: int = -1")
28  .Attr("fail_op_on_rpc_error: bool = true")
29  .Attr("max_rpc_deadline_millis: int = 30000")
30  .Attr("signature_name: string = 'serving_default'")
31  .Input("input_tensor_aliases: string")
32  .Input("input_tensors: T")
33  .Input("output_tensor_aliases: string")
34  .Output("status_code: int32")
35  .Output("status_error_message: string")
36  .Output("output_tensors: output_types")
37  .Attr("output_types: list(type)")
38  .SetShapeFn([](shape_inference::InferenceContext* c) {
39  shape_inference::ShapeHandle unused;
40  // Checks the length of input_tensor_aliases with that of input_tensors.
41  std::vector<shape_inference::ShapeHandle> input_aliases_handle;
42  TF_RETURN_IF_ERROR(
43  c->input("input_tensor_aliases", &input_aliases_handle));
44  TF_RETURN_IF_ERROR(c->WithRank(input_aliases_handle[0], 1, &unused));
45  std::vector<shape_inference::ShapeHandle> inputs_handle;
46  TF_RETURN_IF_ERROR(c->input("input_tensors", &inputs_handle));
47  if (c->Value(c->NumElements(input_aliases_handle[0])) !=
48  inputs_handle.size()) {
49  return errors::InvalidArgument(
50  "'input_tensors' should be equal in length to "
51  "'input_tensor_aliases'. Length of 'input_tensors': ",
52  inputs_handle.size(), ", length of 'input_tensor_aliases': ",
53  c->Value(c->NumElements(input_aliases_handle[0])));
54  }
55 
56  // Checks the length of output_tensor_aliases with that of output_types.
57  DataTypeVector output_types;
58  TF_RETURN_IF_ERROR(c->GetAttr("output_types", &output_types));
59  std::vector<shape_inference::ShapeHandle> output_aliases_handle;
60  TF_RETURN_IF_ERROR(
61  c->input("output_tensor_aliases", &output_aliases_handle));
62  if (c->Value(c->NumElements(output_aliases_handle[0])) !=
63  output_types.size()) {
64  return errors::InvalidArgument(
65  "'output_types' should be equal in length to "
66  "'output_tensor_aliases'. Length of 'output_types': ",
67  output_types.size(), ", length of 'output_tensor_aliases': ",
68  c->Value(c->NumElements(output_aliases_handle[0])));
69  }
70 
71  // We know the shape of the first 2 outputs, but not the rest.
72  TF_RETURN_IF_ERROR(c->set_output("status_code", {c->Scalar()}));
73  TF_RETURN_IF_ERROR(c->set_output("status_error_message", {c->Scalar()}));
74  for (int i = 2; i < c->num_outputs(); ++i) {
75  c->set_output(i, c->UnknownShape());
76  }
77 
78  return Status();
79  })
80  .SetIsStateful()
81  .SetIsDistributedCommunication()
82  .Doc(R"doc(
83 Invokes Predict on a remote graph.
84 fail_op_on_rpc_error: If set true, the Op fails if the rpc fails, and returns
85  the status code as 0 and an empty status_message. Otherwise the
86  Op returns the status of the rpc call, along with the output tensors, if any.
87  Set true by default.
88 max_rpc_deadline_millis: The rpc deadline for remote predict. The actual
89 deadline is min(incoming_rpc_deadline, max_rpc_deadline_millis).
90 signature_name: the signature def for remote graph inference, defaulting to
91 "serving_default".
92 target_address: Address of the server hosting the remote graph.
93 model_name: Model name of the remote TF graph.
94 model_version: the target version for the Predict call. When unset, the
95  default value (-1) implies the latest available version should be used.
96 input_tensor_aliases: Tensor of strings for the input tensor alias names to supply
97  to the RemotePredict call.
98 input_tensors: List of tensors to provide as input. Should be equal in length
99  to 'input_tensor_aliases'.
100 output_tensor_aliases: Tensor of strings for the output tensor alias names to
101  supply to the Predict call.
102 status_code: Returns the status code of the rpc call; basically converting
103  tensorflow::error::Code to it's int value, so 0 means OK.
104 status_error_message: Returns the error message in the rpc status.
105 output_tensors: Tensors returned by the Predict call on the remote graph, which
106  are in the same order as output_tensor_aliases.
107 output_types: A list of types of the output tensors. Length of this list should
108  be equal to the length of 'output_tensor_aliases'.
109 )doc");
110 
111 } // namespace tensorflow