16 #include "tensorflow_serving/servables/tensorflow/tfrt_regressor.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "tensorflow/core/example/example.pb.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/notification.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/platform/threadpool_options.h"
33 #include "tensorflow/core/platform/tracing.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/tfrt/utils/tensor_util.h"
36 #include "tsl/platform/error_logging.h"
37 #include "tensorflow_serving/apis/input.pb.h"
38 #include "tensorflow_serving/apis/model.pb.h"
39 #include "tensorflow_serving/apis/regression.pb.h"
40 #include "tensorflow_serving/apis/regressor.h"
41 #include "tensorflow_serving/servables/tensorflow/util.h"
43 namespace tensorflow {
46 Status PreProcessRegression(
const tfrt::FunctionMetadata& function_metadata) {
47 if (function_metadata.GetInputNames().size() != 1) {
48 return errors::InvalidArgument(
49 strings::StrCat(
"Expected one input Tensor."));
51 if (function_metadata.GetOutputNames().size() != 1) {
52 return errors::InvalidArgument(
53 strings::StrCat(
"Expected one output Tensor."));
56 if (function_metadata.GetInputNames()[0] != kRegressInputs) {
57 return errors::FailedPrecondition(
58 "No regression inputs found in function's metadata, only contains: ",
59 function_metadata.GetInputNames()[0]);
62 if (function_metadata.GetOutputNames()[0] != kRegressOutputs) {
63 return errors::FailedPrecondition(
64 "No regression outputs found in function's metadata, only contains: ",
65 function_metadata.GetOutputNames()[0]);
68 return absl::OkStatus();
71 Status PostProcessRegressionResult(
72 int num_examples,
const std::vector<string>& output_tensor_names,
73 const std::vector<Tensor>& output_tensors, RegressionResult* result) {
74 if (output_tensors.size() != output_tensor_names.size()) {
75 return errors::InvalidArgument(
76 "Expected output_tensors and output_tensor_names to have the same "
80 const Tensor* output_tensor = &output_tensors[0];
82 if (!(output_tensor->dims() == 1 ||
83 (output_tensor->dims() == 2 && output_tensor->dim_size(1) == 1))) {
84 return errors::InvalidArgument(
85 "Expected output Tensor shape to be either [batch_size] or ",
86 "[batch_size, 1] but got ", output_tensor->shape().DebugString());
88 if (num_examples != output_tensor->dim_size(0)) {
89 return errors::InvalidArgument(strings::StrCat(
90 "Input batch size did not match output batch size: ", num_examples,
91 " vs. ", output_tensor->dim_size(0)));
93 if (output_tensor->dtype() != DT_FLOAT) {
94 return errors::InvalidArgument(
"Expected output Tensor of DT_FLOAT. Got: ",
95 DataType_Name(output_tensor->dtype()));
98 if (output_tensor->NumElements() != num_examples) {
99 return errors::InvalidArgument(
"Expected output batch size to be ",
101 ". Got: ", output_tensor->NumElements());
104 const auto& output_tensor_flat = output_tensor->flat<
float>();
105 for (
int i = 0; i < num_examples; ++i) {
106 result->add_regressions()->set_value(output_tensor_flat(i));
108 return absl::OkStatus();
111 Status RunRegress(
const tfrt::SavedModel::RunOptions& run_options,
112 const absl::optional<int64_t>& servable_version,
113 tfrt::SavedModel* saved_model,
114 const RegressionRequest& request,
115 RegressionResponse* response) {
116 const string function_name = request.model_spec().signature_name().empty()
117 ? kDefaultServingSignatureDefKey
118 : request.model_spec().signature_name();
120 const auto function_metadata =
121 saved_model->GetFunctionMetadata(function_name);
122 if (!function_metadata.has_value()) {
123 return errors::FailedPrecondition(
124 strings::StrCat(
"Function \"", function_name,
"\" not found."));
127 MakeModelSpec(request.model_spec().name(), function_name, servable_version,
128 response->mutable_model_spec());
131 TF_RETURN_IF_ERROR(PreProcessRegression(function_metadata.value()));
134 InputToSerializedExampleTensor(request.input(), &input_tensor));
135 std::vector<Tensor> input_tensors;
136 int num_examples = input_tensor.dim_size(0);
137 input_tensors.emplace_back(std::move(input_tensor));
140 std::vector<Tensor> output_tensors;
141 const uint64_t start_microseconds = EnvTime::NowMicros();
142 if (
const auto status = saved_model->Run(run_options, function_name,
143 input_tensors, &output_tensors);
145 if (IsTfrtErrorLoggingEnabled()) {
146 tsl::error_logging::Log(
"TFRT",
"SavedModelRun", status.message())
151 const uint64_t end_microseconds = EnvTime::NowMicros();
152 RecordRuntimeLatency(request.model_spec().name(),
154 end_microseconds - start_microseconds);
157 return PostProcessRegressionResult(
158 num_examples, function_metadata->GetOutputNames(), output_tensors,
159 response->mutable_result());