#include "tensorflow_serving/servables/tensorflow/tfrt_predict_util.h"

#include <algorithm>
#include <optional>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tsl/platform/error_logging.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/servables/tensorflow/predict_util.h"
#include "tensorflow_serving/servables/tensorflow/util.h"

namespace tensorflow {
namespace serving {
namespace {
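
// Validates the PredictRequest and converts its input tensor protos to
// tensors; inputs missing from the request fall back to the function's
// default inputs.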
Status PreProcessPredictionWithoutOutputFilter(
    const tfrt::FunctionMetadata& function_metadata,
    const PredictRequest& request, std::vector<Tensor>* input_tensors) {
  input_tensors->reserve(function_metadata.GetInputNames().size());
  for (int i = 0; i < function_metadata.GetInputNames().size(); ++i) {
    const auto& input_name = function_metadata.GetInputNames()[i];
    const auto input = request.inputs().find(input_name);
    if (input == request.inputs().end()) {
      const auto& default_inputs = function_metadata.GetDefaultInputs();
      const auto& default_input = default_inputs.find(input_name);
      if (default_input == default_inputs.end()) {
        const std::set<string> request_inputs = GetMapKeys(request.inputs());
        const std::set<string> required_inputs(
            function_metadata.GetInputNames().begin(),
            function_metadata.GetInputNames().end());
        const std::set<string> sent_extra =
            SetDifference(request_inputs, required_inputs);
        const std::set<string> missing =
            SetDifference(SetDifference(required_inputs, request_inputs),
                          saved_model::GetMapKeys(default_inputs));
        return errors::InvalidArgument(absl::StrCat(
            "Request inputs do not match required inputs for model `",
            request.model_spec().name(), "`. Sent extra: {",
            absl::StrJoin(sent_extra, ","), "}. Missing but required: {",
            absl::StrJoin(missing, ","), "}."));
      }
      // No value in the request for this input; use the signature's default.
      Tensor tensor;
      if (!tensor.FromProto(default_input->second)) {
        return errors::InvalidArgument(
            absl::StrCat("tensor parsing error: ", input_name));
      }
      input_tensors->emplace_back(std::move(tensor));
      continue;
    }
    Tensor tensor;
    if (!tensor.FromProto(input->second)) {
      return errors::InvalidArgument(
          absl::StrCat("tensor parsing error: ", input_name));
    }
    const auto expected_dtype = function_metadata.GetInputSpecs()[i].dtype;
    if (expected_dtype != DT_INVALID && tensor.dtype() != expected_dtype) {
      return errors::InvalidArgument(
          absl::StrCat("Expected input ", input_name, " to be ",
                       DataTypeString(expected_dtype), " but got ",
                       DataTypeString(tensor.dtype()), "."));
    }
    input_tensors->emplace_back(std::move(tensor));
  }
  return absl::OkStatus();
}
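
// Serializes the output tensors into the PredictResponse, honoring the
// request's output_filter and the requested serialization option.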
Status PostProcessPredictionResultWithoutOutputFilter(
    const std::vector<string>& output_tensor_names,
    const std::vector<Tensor>& output_tensors,
    const internal::PredictResponseTensorSerializationOption option,
    const PredictRequest& request, PredictResponse* response) {
  if (output_tensor_names.size() != output_tensors.size()) {
    return errors::Unknown("Predict internal error.");
  }

  std::unordered_set<string> output_filter(request.output_filter().begin(),
                                           request.output_filter().end());
  int output_size = 0;
  for (int i = 0; i < output_tensors.size(); ++i) {
    if (!output_filter.empty() &&
        output_filter.find(output_tensor_names[i]) == output_filter.end()) {
      continue;
    }
    ++output_size;
    switch (option) {
      case internal::PredictResponseTensorSerializationOption::kAsProtoField: {
        output_tensors[i].AsProtoField(
            &((*response->mutable_outputs())[output_tensor_names[i]]));
      } break;
      case internal::PredictResponseTensorSerializationOption::
          kAsProtoContent: {
        output_tensors[i].AsProtoTensorContent(
            &((*response->mutable_outputs())[output_tensor_names[i]]));
      } break;
    }
  }
  if (!output_filter.empty() && output_filter.size() != output_size) {
    return errors::InvalidArgument(absl::StrCat(
        "output_filter contains non-existent output names. output_filter: ",
        absl::StrJoin(output_filter, ",")));
  }
  return absl::OkStatus();
}
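
// Returns true if the request's output_filter is empty or names exactly the
// function's outputs, i.e. no per-tensor filtering is actually needed.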
bool IsOutputFilterEmptyOrFullSet(
    const PredictRequest& request,
    const tfrt::FunctionMetadata& function_metadata) {
  if (request.output_filter().empty()) return true;
  if (request.output_filter().size() !=
      function_metadata.GetOutputNames().size())
    return false;
  std::vector<absl::string_view> output_filter_names(
      request.output_filter().begin(), request.output_filter().end());
  std::vector<absl::string_view> func_output_names(
      function_metadata.GetOutputNames().begin(),
      function_metadata.GetOutputNames().end());
  std::sort(output_filter_names.begin(), output_filter_names.end());
  std::sort(func_output_names.begin(), func_output_names.end());
  return output_filter_names == func_output_names;
}

}  // namespace

namespace internal {
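
// Runs Predict on `saved_model`, choosing SavedModel::Run() when no real
// output filtering is needed and RunByTensorNames() otherwise.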
Status RunPredict(
    const tfrt::SavedModel::RunOptions& run_options,
    const absl::optional<int64_t>& servable_version,
    const internal::PredictResponseTensorSerializationOption option,
    tfrt::SavedModel* saved_model, const PredictRequest& request,
    PredictResponse* response,
    const thread::ThreadPoolOptions& thread_pool_options) {
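  // An empty signature name in the request means the default serving
  // signature.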
  const std::string function_name =
      request.model_spec().signature_name().empty()
          ? kDefaultServingSignatureDefKey
          : request.model_spec().signature_name();

  const auto function_metadata =
      saved_model->GetFunctionMetadata(function_name);
  if (!function_metadata.has_value()) {
    return errors::FailedPrecondition(
        strings::StrCat("Function \"", function_name, "\" not found."));
  }

  MakeModelSpec(request.model_spec().name(), function_name, servable_version,
                response->mutable_model_spec());
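
  // If the caller supplied thread pools, wrap them in a work queue and run
  // the model on it.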
  auto run_opts = run_options;
  std::optional<tensorflow::tfrt_stub::TfThreadPoolWorkQueue> thread_pool;
  if (thread_pool_options.inter_op_threadpool != nullptr) {
    thread_pool.emplace(
        /*intra_op_threadpool=*/thread_pool_options.intra_op_threadpool,
        /*inter_op_threadpool=*/thread_pool_options.inter_op_threadpool);
    run_opts.work_queue = &(*thread_pool);
  }
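
  // Fast path: no real output filtering is requested, so run the whole
  // signature with SavedModel::Run().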
  if (IsOutputFilterEmptyOrFullSet(request, function_metadata.value())) {
    std::vector<Tensor> input_tensors;
    TF_RETURN_IF_ERROR(PreProcessPredictionWithoutOutputFilter(
        function_metadata.value(), request, &input_tensors));

    std::vector<Tensor> outputs;
    const uint64_t start_microseconds = EnvTime::NowMicros();
    if (const auto status =
            saved_model->Run(run_opts, function_name, input_tensors, &outputs);
        !status.ok()) {
      if (IsTfrtErrorLoggingEnabled()) {
        tsl::error_logging::Log("TFRT", "SavedModelRun", status.message())
            .IgnoreError();
      }
      return status;
    }
    const uint64_t end_microseconds = EnvTime::NowMicros();
    RecordRuntimeLatency(request.model_spec().name(), /*api=*/"Predict",
                         /*runtime=*/"TFRT",
                         end_microseconds - start_microseconds);

    return PostProcessPredictionResultWithoutOutputFilter(
        function_metadata->GetOutputNames(), outputs, option, request,
        response);
  }
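
  // Slow path: the output filter selects a strict subset of outputs, so look
  // up the SignatureDef and run only the requested tensors by name.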
  const auto& metagraph_def = saved_model->GetMetaGraphDef();
  auto iter = metagraph_def.signature_def().find(function_name);
  if (iter == metagraph_def.signature_def().end()) {
    return errors::FailedPrecondition(strings::StrCat(
        "Serving signature key \"", function_name, "\" not found."));
  }
  const SignatureDef& signature = iter->second;

  std::vector<std::pair<string, Tensor>> input_tensors;
  std::vector<string> output_tensor_names;
  std::vector<string> output_tensor_aliases;
  TF_RETURN_IF_ERROR(PreProcessPrediction(signature, request, &input_tensors,
                                          &output_tensor_names,
                                          &output_tensor_aliases));

  const uint64_t start_microseconds = EnvTime::NowMicros();
  std::vector<Tensor> outputs;
  if (const auto status = saved_model->RunByTensorNames(
          run_opts, input_tensors, output_tensor_names,
          /*target_node_names=*/{}, &outputs);
      !status.ok()) {
    if (IsTfrtErrorLoggingEnabled()) {
      tsl::error_logging::Log("TFRT", "SavedModelRun", status.message())
          .IgnoreError();
    }
    return status;
  }
  const uint64_t end_microseconds = EnvTime::NowMicros();
  RecordRuntimeLatency(request.model_spec().name(), /*api=*/"Predict",
                       /*runtime=*/"TFRT",
                       end_microseconds - start_microseconds);

  return PostProcessPredictionResult(output_tensor_aliases, outputs, option,
                                     response);
}

}  // namespace internal

Status RunPredict(const tfrt::SavedModel::RunOptions& run_options,
                  const absl::optional<int64_t>& servable_version,
                  tfrt::SavedModel* saved_model, const PredictRequest& request,
                  PredictResponse* response,
                  const thread::ThreadPoolOptions& thread_pool_options) {
  return internal::RunPredict(
      run_options, servable_version,
      internal::PredictResponseTensorSerializationOption::kAsProtoField,
      saved_model, request, response, thread_pool_options);
}

}  // namespace serving
}  // namespace tensorflow