16 #include "tensorflow_serving/servables/tensorflow/multi_inference.h"
21 #include "tensorflow/cc/saved_model/signature_constants.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/platform/tracing.h"
24 #include "tensorflow_serving/apis/input.pb.h"
25 #include "tensorflow_serving/apis/model.pb.h"
26 #include "tensorflow_serving/servables/tensorflow/classifier.h"
27 #include "tensorflow_serving/servables/tensorflow/regressor.h"
28 #include "tensorflow_serving/servables/tensorflow/util.h"
30 namespace tensorflow {
// Runs every task in `request` and adds one result per task to `response`.
//
// The work is done in two passes over request.tasks():
//   1. Validation / pre-processing: each task must name a model, all tasks
//      must name the same model, and no signature may be evaluated twice.
//      The signature is looked up in meta_graph_def_ and, depending on
//      task.method_name(), handed to PreProcessClassification or
//      PreProcessRegression to collect input and output tensor names.
//   2. Post-processing: after one PerformOneShotTensorComputation call, each
//      task's signature is resolved again and the shared `outputs` are
//      converted into a classification or regression result.
//
// Returns InvalidArgument for an empty or mismatching model name, a duplicate
// or unknown signature, and (in the second pass) an unrecognized method name;
// returns Unimplemented for an unsupported method name in the first pass.
// Errors from the computation and post-processing helpers are propagated via
// TF_RETURN_IF_ERROR. Returns absl::OkStatus() otherwise.
//
// NOTE(review): this view of the function appears to be missing lines
// (closing braces, the declarations of `input_name` and `num_examples`, and
// the tails of some multi-line calls) -- confirm against the full file.
33 Status TensorFlowMultiInferenceRunner::Infer(
34 const RunOptions& run_options,
const MultiInferenceRequest& request,
35 MultiInferenceResponse* response) {
36 TRACELITERAL(
"TensorFlowMultiInferenceRunner::Infer");
// Model name shared by all tasks; stays empty until the first task is seen.
38 string model_name =
"";
// Sets de-duplicate signature names and the tensor names gathered per task.
39 std::set<string> signature_names;
40 std::set<string> input_tensor_name_set;
41 std::set<string> output_tensor_name_set;
// Pass 1: validate each task and collect the tensor names it needs.
42 for (
const auto& task : request.tasks()) {
43 if (task.model_spec().name().empty()) {
44 return errors::InvalidArgument(
45 "Found ModelSpec with an empty model name.");
// The first task fixes the model name; every later task must match it.
47 if (model_name.empty()) {
48 model_name = task.model_spec().name();
49 }
else if (model_name != task.model_spec().name()) {
50 return errors::InvalidArgument(
51 "All ModelSpecs in a MultiInferenceRequest must access the same "
// An empty signature name falls back to kDefaultServingSignatureDefKey.
55 const string signature_name = task.model_spec().signature_name().empty()
56 ? kDefaultServingSignatureDefKey
57 : task.model_spec().signature_name();
// Each signature may be requested at most once per request.
59 if (signature_names.find(signature_name) != signature_names.end()) {
60 return errors::InvalidArgument(strings::StrCat(
61 "Duplicate evaluation of signature: ", signature_name));
63 signature_names.insert(signature_name);
// Resolve the signature definition in the loaded meta graph.
65 auto iter = meta_graph_def_->signature_def().find(signature_name);
66 if (iter == meta_graph_def_->signature_def().end()) {
67 return errors::InvalidArgument(strings::StrCat(
68 "Requested signature not found in model graph: ", signature_name));
71 std::vector<string> output_names;
// Dispatch on the task's method name; the pre-process helpers fill
// `input_name` (declaration not visible in this view) and `output_names`.
73 if (task.method_name() == kClassifyMethodName) {
75 PreProcessClassification(iter->second, &input_name, &output_names));
76 }
else if (task.method_name() == kRegressMethodName) {
78 PreProcessRegression(iter->second, &input_name, &output_names));
// Any other method name is rejected as Unimplemented.
80 return errors::Unimplemented(
"Unsupported signature method_name: ",
// Accumulate this task's tensor names into the request-wide sets.
83 input_tensor_name_set.insert(input_name);
84 for (
const auto& output_tensor_name : output_names) {
85 output_tensor_name_set.insert(output_tensor_name);
// Copy the de-duplicated output names into a vector for the computation
// and post-processing calls below.
89 const std::vector<string> output_tensor_names(output_tensor_name_set.begin(),
90 output_tensor_name_set.end());
92 std::vector<Tensor> outputs;
// Single computation call for the whole request; it fills `outputs` and
// `num_examples` (declaration of `num_examples` not visible in this view).
94 TF_RETURN_IF_ERROR(PerformOneShotTensorComputation(
95 run_options, request.input(), input_tensor_name_set, output_tensor_names,
96 session_, &outputs, &num_examples, thread_pool_options_));
97 RecordRequestExampleCount(model_name, num_examples);
99 TRACELITERAL(
"PostProcessResults");
// Pass 2: turn the shared outputs into one result per task.
100 for (
const auto& task : request.tasks()) {
// Same signature resolution as in pass 1 (empty name -> default key).
101 const string signature_name = task.model_spec().signature_name().empty()
102 ? kDefaultServingSignatureDefKey
103 : task.model_spec().signature_name();
104 auto iter = meta_graph_def_->signature_def().find(signature_name);
105 if (iter == meta_graph_def_->signature_def().end()) {
106 return errors::InvalidArgument(strings::StrCat(
107 "Requested signature not found in model graph: ", signature_name));
// Each branch adds a new entry to response->results() and fills it in.
109 if (task.method_name() == kClassifyMethodName) {
110 TF_RETURN_IF_ERROR(PostProcessClassificationResult(
111 iter->second, num_examples, output_tensor_names, outputs,
112 response->add_results()->mutable_classification_result()));
113 }
else if (task.method_name() == kRegressMethodName) {
114 TF_RETURN_IF_ERROR(PostProcessRegressionResult(
115 iter->second, num_examples, output_tensor_names, outputs,
116 response->add_results()->mutable_regression_result()));
// Note: this pass reports an unknown method as InvalidArgument, whereas
// pass 1 uses Unimplemented.
118 return errors::InvalidArgument(
"Unrecognized signature method_name: ",
// Fill in the model_spec of the most recently added result (index
// results_size() - 1) from the task's model name and signature name; the
// remaining MakeModelSpec argument(s) are not visible in this view.
121 MakeModelSpec(task.model_spec().name(), task.model_spec().signature_name(),
123 response->mutable_results(response->results_size() - 1)
124 ->mutable_model_spec());
126 return absl::OkStatus();
// Free-function entry point for multi-inference: constructs a
// TensorFlowMultiInferenceRunner from `session`, a pointer to
// `meta_graph_def`, `servable_version` and `thread_pool_options`, then
// forwards `run_options`, `request` and `response` to its Infer() method and
// returns the resulting Status.
//
// The runner is a local object that receives the address of `meta_graph_def`
// and the `session` pointer.
// NOTE(review): the function's closing brace is not visible in this view.
129 Status RunMultiInference(
130 const RunOptions& run_options,
const MetaGraphDef& meta_graph_def,
131 const absl::optional<int64_t>& servable_version, Session* session,
132 const MultiInferenceRequest& request, MultiInferenceResponse* response,
133 const tensorflow::thread::ThreadPoolOptions& thread_pool_options) {
134 TRACELITERAL(
"RunMultiInference");
136 TensorFlowMultiInferenceRunner inference_runner(
137 session, &meta_graph_def, servable_version, thread_pool_options);
138 return inference_runner.Infer(run_options, request, response);