16 #include "tensorflow_serving/servables/tensorflow/tfrt_multi_inference.h"
18 #include "tensorflow/cc/saved_model/loader.h"
19 #include "tensorflow/cc/saved_model/signature_constants.h"
20 #include "tensorflow/core/tfrt/utils/tensor_util.h"
21 #include "tsl/platform/error_logging.h"
22 #include "tensorflow_serving/apis/input.pb.h"
23 #include "tensorflow_serving/apis/model.pb.h"
24 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
25 #include "tensorflow_serving/servables/tensorflow/tfrt_regressor.h"
26 #include "tensorflow_serving/servables/tensorflow/util.h"
28 namespace tensorflow {
31 Status RunMultiInference(
const tfrt::SavedModel::RunOptions& run_options,
32 const absl::optional<int64_t>& servable_version,
33 tfrt::SavedModel* saved_model,
34 const MultiInferenceRequest& request,
35 MultiInferenceResponse* response) {
38 InputToSerializedExampleTensor(request.input(), &input_tensor));
39 std::vector<std::vector<Tensor>> input_tensors;
40 int num_examples = input_tensor.dim_size(0);
41 input_tensors.resize(request.tasks_size());
42 for (
int i = 0; i < request.tasks_size(); ++i) {
43 input_tensors[i].emplace_back(input_tensor);
47 std::string model_name =
"";
48 std::set<std::string> function_names_set;
49 std::vector<std::string> function_names;
50 function_names.reserve(request.tasks_size());
51 for (
const auto& task : request.tasks()) {
52 if (task.model_spec().name().empty()) {
53 return errors::InvalidArgument(
54 "Found ModelSpec with an empty model name.");
56 if (model_name.empty()) {
57 model_name = task.model_spec().name();
58 }
else if (model_name != task.model_spec().name()) {
59 return errors::InvalidArgument(
60 "All ModelSpecs in a MultiInferenceRequest must access the same "
64 const std::string function_name = task.model_spec().signature_name().empty()
65 ? kDefaultServingSignatureDefKey
66 : task.model_spec().signature_name();
70 if (function_names_set.find(function_name) != function_names_set.end()) {
71 return errors::InvalidArgument(strings::StrCat(
72 "Duplicate evaluation of signature: ", function_name));
74 function_names_set.insert(function_name);
75 function_names.push_back(function_name);
77 const auto function_metadata =
78 saved_model->GetFunctionMetadata(function_name);
79 if (!function_metadata.has_value()) {
80 return errors::InvalidArgument(
81 strings::StrCat(
"Function \"", function_name,
"\" not found."));
84 if (task.method_name() == kClassifyMethodName) {
85 TF_RETURN_IF_ERROR(PreProcessClassification(function_metadata.value()));
86 }
else if (task.method_name() == kRegressMethodName) {
87 TF_RETURN_IF_ERROR(PreProcessRegression(function_metadata.value()));
89 return errors::Unimplemented(
"Unsupported signature method_name: ",
95 std::vector<std::vector<Tensor>> output_tensors;
96 if (
const auto status = saved_model->RunMultipleSignatures(
97 run_options, function_names, input_tensors, &output_tensors);
99 if (IsTfrtErrorLoggingEnabled()) {
100 tsl::error_logging::Log(
"TFRT",
"SavedModelRun", status.message())
107 for (
int i = 0; i < request.tasks_size(); ++i) {
110 const auto function_metadata =
111 saved_model->GetFunctionMetadata(function_names[i]);
112 DCHECK(function_metadata.has_value());
113 if (request.tasks(i).method_name() == kClassifyMethodName) {
114 TF_RETURN_IF_ERROR(PostProcessClassificationResult(
115 num_examples, function_metadata->GetOutputNames(), output_tensors[i],
116 response->add_results()->mutable_classification_result()));
117 }
else if (request.tasks(i).method_name() == kRegressMethodName) {
118 TF_RETURN_IF_ERROR(PostProcessRegressionResult(
119 num_examples, function_metadata->GetOutputNames(), output_tensors[i],
120 response->add_results()->mutable_regression_result()));
122 return errors::InvalidArgument(
"Unrecognized signature method_name: ",
123 request.tasks(i).method_name());
125 MakeModelSpec(request.tasks(i).model_spec().name(),
126 request.tasks(i).model_spec().signature_name(),
128 response->mutable_results(response->results_size() - 1)
129 ->mutable_model_spec());
131 RecordRequestExampleCount(model_name, num_examples);
132 return absl::OkStatus();