#include "tensorflow_serving/servables/tensorflow/predict_util.h"

#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow_serving/servables/tensorflow/util.h"
34 namespace tensorflow {
38 Status VerifySignature(
const SignatureDef& signature) {
39 if (GetSignatureMethodNameCheckFeature() &&
40 signature.method_name() != kPredictMethodName &&
41 signature.method_name() != kClassifyMethodName &&
42 signature.method_name() != kRegressMethodName) {
43 return errors::Internal(strings::StrCat(
44 "Expected prediction signature method_name to be one of {",
45 kPredictMethodName,
", ", kClassifyMethodName,
", ", kRegressMethodName,
46 "}. Was: ", signature.method_name()));
48 return absl::OkStatus();
51 Status VerifyRequestInputsSize(
const SignatureDef& signature,
52 const PredictRequest& request) {
53 if (request.inputs().size() > signature.inputs().size() ||
54 (request.inputs().size() < signature.inputs().size() &&
55 signature.defaults().empty())) {
56 const std::set<string> request_inputs = GetMapKeys(request.inputs());
57 const std::set<string> signature_inputs = GetMapKeys(signature.inputs());
58 const std::set<string> sent_extra =
59 SetDifference(request_inputs, signature_inputs);
60 const std::set<string> missing =
61 SetDifference(signature_inputs, request_inputs);
62 return tensorflow::Status(
63 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
65 "input size does not match signature: ", request.inputs().size(),
66 "!=", signature.inputs().size(),
" len({",
67 absl::StrJoin(request_inputs,
","),
"}) != len({",
68 absl::StrJoin(signature_inputs,
","),
"}). Sent extra: {",
69 absl::StrJoin(sent_extra,
","),
"}. Missing but required: {",
70 absl::StrJoin(missing,
","),
"}."));
72 return absl::OkStatus();
79 const RunOptions& run_options,
const MetaGraphDef& meta_graph_def,
80 const absl::optional<int64_t>& servable_version,
81 const internal::PredictResponseTensorSerializationOption option,
82 Session* session,
const PredictRequest& request, PredictResponse* response,
83 const thread::ThreadPoolOptions& thread_pool_options) {
85 const string signature_name = request.model_spec().signature_name().empty()
86 ? kDefaultServingSignatureDefKey
87 : request.model_spec().signature_name();
88 auto iter = meta_graph_def.signature_def().find(signature_name);
89 if (iter == meta_graph_def.signature_def().end()) {
90 return errors::FailedPrecondition(strings::StrCat(
91 "Serving signature key \"", signature_name,
"\" not found."));
93 const SignatureDef& signature = iter->second;
95 MakeModelSpec(request.model_spec().name(), signature_name, servable_version,
96 response->mutable_model_spec());
98 std::vector<std::pair<string, Tensor>> input_tensors;
99 std::vector<string> output_tensor_names;
100 std::vector<string> output_tensor_aliases;
101 TF_RETURN_IF_ERROR(PreProcessPrediction(signature, request, &input_tensors,
102 &output_tensor_names,
103 &output_tensor_aliases));
104 std::vector<Tensor> outputs;
105 RunMetadata run_metadata;
106 const uint64_t start_microseconds = EnvTime::NowMicros();
107 TF_RETURN_IF_ERROR(session->Run(run_options, input_tensors,
108 output_tensor_names, {}, &outputs,
109 &run_metadata, thread_pool_options));
110 const uint64_t end_microseconds = EnvTime::NowMicros();
111 RecordRuntimeLatency(request.model_spec().name(),
"Predict",
113 end_microseconds - start_microseconds);
115 return PostProcessPredictionResult(output_tensor_aliases, outputs, option,
119 Status PreProcessPrediction(
const SignatureDef& signature,
120 const PredictRequest& request,
121 std::vector<std::pair<string, Tensor>>* inputs,
122 std::vector<string>* output_tensor_names,
123 std::vector<string>* output_tensor_aliases) {
124 TF_RETURN_IF_ERROR(VerifySignature(signature));
125 TF_RETURN_IF_ERROR(VerifyRequestInputsSize(signature, request));
127 saved_model::GetInputValues(signature, request.inputs(), *inputs));
130 std::set<string> seen_outputs;
131 std::vector<string> output_filter(request.output_filter().begin(),
132 request.output_filter().end());
133 for (
auto& alias : output_filter) {
134 auto iter = signature.outputs().find(alias);
135 if (iter == signature.outputs().end()) {
136 return tensorflow::Status(
137 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
138 strings::StrCat(
"output tensor alias not found in signature: ", alias,
139 " Outputs expected to be in the set {",
140 absl::StrJoin(GetMapKeys(signature.outputs()),
","),
143 if (seen_outputs.find(alias) != seen_outputs.end()) {
144 return tensorflow::Status(
145 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
146 "duplicate output tensor alias: " + alias);
148 seen_outputs.insert(alias);
149 output_tensor_names->emplace_back(iter->second.name());
150 output_tensor_aliases->emplace_back(alias);
154 if (output_tensor_names->empty()) {
155 for (
auto& iter : signature.outputs()) {
156 output_tensor_names->emplace_back(iter.second.name());
157 output_tensor_aliases->emplace_back(iter.first);
160 return absl::OkStatus();
163 Status PostProcessPredictionResult(
164 const std::vector<string>& output_tensor_aliases,
165 const std::vector<Tensor>& output_tensors,
166 const internal::PredictResponseTensorSerializationOption option,
167 PredictResponse* response) {
169 if (output_tensors.size() != output_tensor_aliases.size()) {
170 return tensorflow::Status(
171 static_cast<tensorflow::errors::Code
>(absl::StatusCode::kUnknown),
172 "Predict internal error");
175 case internal::PredictResponseTensorSerializationOption::kAsProtoField: {
176 for (
int i = 0; i < output_tensors.size(); i++) {
177 output_tensors[i].AsProtoField(
178 &((*response->mutable_outputs())[output_tensor_aliases[i]]));
181 case internal::PredictResponseTensorSerializationOption::kAsProtoContent: {
182 for (
int i = 0; i < output_tensors.size(); i++) {
183 output_tensors[i].AsProtoTensorContent(
184 &((*response->mutable_outputs())[output_tensor_aliases[i]]));
189 return absl::OkStatus();
194 Status RunPredict(
const RunOptions& run_options,
195 const MetaGraphDef& meta_graph_def,
196 const absl::optional<int64_t>& servable_version,
197 Session* session,
const PredictRequest& request,
198 PredictResponse* response,
199 const thread::ThreadPoolOptions& thread_pool_options) {
200 return internal::RunPredict(
201 run_options, meta_graph_def, servable_version,
202 internal::PredictResponseTensorSerializationOption::kAsProtoField,
203 session, request, response, thread_pool_options);