#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

REGISTER_OP("TfServingRemotePredict")
    .Attr("T: list(type)")
    .Attr("target_address: string = ''")
    .Attr("model_name: string = ''")
    .Attr("model_version: int = -1")
    .Attr("fail_op_on_rpc_error: bool = true")
    .Attr("max_rpc_deadline_millis: int = 30000")
    .Attr("signature_name: string = 'serving_default'")
    .Input("input_tensor_aliases: string")
    .Input("input_tensors: T")
    .Input("output_tensor_aliases: string")
    .Output("status_code: int32")
    .Output("status_error_message: string")
    .Output("output_tensors: output_types")
    .Attr("output_types: list(type)")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // Checks that input_tensor_aliases is a vector and that its length
      // matches the number of input tensors.
      std::vector<shape_inference::ShapeHandle> input_aliases_handle;
      TF_RETURN_IF_ERROR(
          c->input("input_tensor_aliases", &input_aliases_handle));
      TF_RETURN_IF_ERROR(c->WithRank(input_aliases_handle[0], 1, &unused));
      std::vector<shape_inference::ShapeHandle> inputs_handle;
      TF_RETURN_IF_ERROR(c->input("input_tensors", &inputs_handle));
      if (c->Value(c->NumElements(input_aliases_handle[0])) !=
          inputs_handle.size()) {
        return errors::InvalidArgument(
            "'input_tensors' should be equal in length to "
            "'input_tensor_aliases'. Length of 'input_tensors': ",
            inputs_handle.size(), ", length of 'input_tensor_aliases': ",
            c->Value(c->NumElements(input_aliases_handle[0])));
      }

      // Checks that the length of output_tensor_aliases matches the length of
      // output_types.
      DataTypeVector output_types;
      TF_RETURN_IF_ERROR(c->GetAttr("output_types", &output_types));
      std::vector<shape_inference::ShapeHandle> output_aliases_handle;
      TF_RETURN_IF_ERROR(
          c->input("output_tensor_aliases", &output_aliases_handle));
      if (c->Value(c->NumElements(output_aliases_handle[0])) !=
          output_types.size()) {
        return errors::InvalidArgument(
            "'output_types' should be equal in length to "
            "'output_tensor_aliases'. Length of 'output_types': ",
            output_types.size(), ", length of 'output_tensor_aliases': ",
            c->Value(c->NumElements(output_aliases_handle[0])));
      }

      // Sets the output shapes: the first two outputs are scalars; the shapes
      // of the remote output tensors are unknown until the rpc returns.
      TF_RETURN_IF_ERROR(c->set_output("status_code", {c->Scalar()}));
      TF_RETURN_IF_ERROR(c->set_output("status_error_message", {c->Scalar()}));
      for (int i = 2; i < c->num_outputs(); ++i) {
        c->set_output(i, c->UnknownShape());
      }
      return OkStatus();
    })
    .SetIsDistributedCommunication()
    .Doc(R"doc(
Invokes Predict on a remote graph.

fail_op_on_rpc_error: If set to true, the Op fails if the rpc fails, and
  returns the status code as 0 and an empty status_error_message. Otherwise
  the Op returns the status of the rpc call, along with the output tensors,
  if any. Set true by default.
max_rpc_deadline_millis: The rpc deadline for remote predict. The actual
  deadline is min(incoming_rpc_deadline, max_rpc_deadline_millis).
signature_name: the signature def for remote graph inference, defaulting to
  'serving_default'.
target_address: Address of the server hosting the remote graph.
model_name: Model name of the remote TF graph.
model_version: the target version for the Predict call. When unset, the default
  value (-1) implies the latest available version should be used.
input_tensor_aliases: Tensor of strings for the input tensor alias names to
  supply to the RemotePredict call.
input_tensors: List of tensors to provide as input. Should be equal in length
  to 'input_tensor_aliases'.
output_tensor_aliases: Tensor of strings for the output tensor alias names to
  supply to the Predict call.
status_code: Returns the status code of the rpc call; basically converting
  tensorflow::error::Code to its int value, so 0 means OK.
status_error_message: Returns the error message in the rpc status.
output_tensors: Tensors returned by the Predict call on the remote graph, which
  are in the same order as output_tensor_aliases.
output_types: A list of types of the output tensors. Length of this list should
  be equal to the length of 'output_tensor_aliases'.
)doc");

}  // namespace tensorflow
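
// Example (not part of the op registration): a minimal sketch of how a
// NodeDef for TfServingRemotePredict could be built from C++ with
// NodeDefBuilder. The node and input names, the "localhost:8500" address,
// the "my_model" model name, and the single DT_FLOAT input/output are
// illustrative assumptions, not values prescribed by the op.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"

namespace remote_predict_example {

tensorflow::Status BuildRemotePredictNodeDef(tensorflow::NodeDef* node_def) {
  using tensorflow::DataTypeVector;
  using tensorflow::DT_FLOAT;
  using tensorflow::DT_STRING;
  using tensorflow::NodeDefBuilder;

  // One float input tensor; the dtypes of this list populate the 'T' attr.
  std::vector<NodeDefBuilder::NodeOut> input_tensors;
  input_tensors.emplace_back("in0", 0, DT_FLOAT);
  // Expect one float tensor back from the remote Predict call.
  const DataTypeVector output_types = {DT_FLOAT};

  return NodeDefBuilder("remote_predict", "TfServingRemotePredict")
      .Input("input_aliases", 0, DT_STRING)   // input_tensor_aliases
      .Input(input_tensors)                   // input_tensors
      .Input("output_aliases", 0, DT_STRING)  // output_tensor_aliases
      .Attr("output_types", output_types)
      .Attr("target_address", "localhost:8500")  // placeholder address
      .Attr("model_name", "my_model")            // placeholder model name
      .Finalize(node_def);
}

}  // namespace remote_predict_example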