16 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
18 #include "tensorflow/cc/saved_model/signature_constants.h"
19 #include "tensorflow/core/example/example.pb.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/notification.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/platform/threadpool.h"
25 #include "tensorflow/core/platform/threadpool_options.h"
26 #include "tensorflow/core/platform/tracing.h"
27 #include "tensorflow/core/platform/types.h"
28 #include "tensorflow/core/tfrt/utils/tensor_util.h"
29 #include "tsl/platform/error_logging.h"
30 #include "tensorflow_serving/apis/classification.pb.h"
31 #include "tensorflow_serving/apis/classifier.h"
32 #include "tensorflow_serving/apis/input.pb.h"
33 #include "tensorflow_serving/apis/model.pb.h"
34 #include "tensorflow_serving/servables/tensorflow/util.h"
36 namespace tensorflow {
39 Status PreProcessClassification(
40 const tfrt::FunctionMetadata& function_metadata) {
41 if (function_metadata.GetInputNames().size() != 1) {
42 return errors::InvalidArgument(
43 strings::StrCat(
"Expected one input Tensor."));
45 if (function_metadata.GetOutputNames().size() != 1 &&
46 function_metadata.GetOutputNames().size() != 2) {
47 return errors::InvalidArgument(
48 strings::StrCat(
"Expected one or two output Tensors, found ",
49 function_metadata.GetOutputNames().size()));
52 if (function_metadata.GetInputNames()[0] != kClassifyInputs) {
53 return errors::FailedPrecondition(
54 "No classification inputs found in function's metadata, only "
56 function_metadata.GetInputNames()[0]);
59 bool find_output_classes =
false;
60 bool find_output_scores =
false;
61 for (
const std::string& output_name : function_metadata.GetOutputNames()) {
62 if (output_name == kClassifyOutputClasses) {
63 find_output_classes =
true;
64 }
else if ((output_name == kClassifyOutputScores)) {
65 find_output_scores =
true;
69 if ((function_metadata.GetOutputNames().size() == 1 && !find_output_classes &&
70 !find_output_scores) ||
71 (function_metadata.GetOutputNames().size() == 2 &&
72 !(find_output_classes && find_output_scores))) {
73 return errors::FailedPrecondition(strings::StrCat(
74 "Expected classification function outputs to contain",
"\"",
75 kClassifyOutputClasses,
"\" and/or \"", kClassifyOutputScores,
"\". "));
78 return absl::OkStatus();
81 Status PostProcessClassificationResult(
82 int num_examples,
const std::vector<string>& output_names,
83 const std::vector<Tensor>& output_tensors, ClassificationResult* result) {
84 if (output_tensors.size() != output_names.size()) {
85 return errors::InvalidArgument(strings::StrCat(
86 "Unexpected output tensors size. Expected ", output_names.size(),
87 " output tensor(s). Got: ", output_tensors.size()));
90 const Tensor* classes =
nullptr;
91 const Tensor* scores =
nullptr;
92 for (
int i = 0; i < output_tensors.size(); ++i) {
93 if (output_names[i] == kClassifyOutputClasses) {
94 classes = &output_tensors[i];
95 }
else if (output_names[i] == kClassifyOutputScores) {
96 scores = &output_tensors[i];
102 if (classes->dims() != 2) {
103 return errors::InvalidArgument(
104 "Expected Tensor shape: [batch_size num_classes] but got ",
105 classes->shape().DebugString());
107 if (classes->dtype() != DT_STRING) {
108 return errors::InvalidArgument(
109 "Expected classes Tensor of DT_STRING. Got: ",
110 DataType_Name(classes->dtype()));
112 if (classes->dim_size(0) != num_examples) {
113 return errors::InvalidArgument(
"Expected classes output batch size of ",
115 ". Got: ", classes->dim_size(0));
120 if (scores->dims() != 2) {
121 return errors::InvalidArgument(
122 "Expected Tensor shape: [batch_size num_classes] but got ",
123 scores->shape().DebugString());
125 if (scores->dtype() != DT_FLOAT) {
126 return errors::InvalidArgument(
127 "Expected scores Tensor of DT_FLOAT. Got: ",
128 DataType_Name(scores->dtype()));
130 if (scores->dim_size(0) != num_examples) {
131 return errors::InvalidArgument(
"Expected scores output batch size of ",
133 ". Got: ", scores->dim_size(0));
139 if (classes && scores) {
141 if (classes->dim_size(1) != scores->dim_size(1)) {
142 return errors::InvalidArgument(
143 "Tensors class and score should match in dim_size(1). Got ",
144 classes->dim_size(1),
" vs. ", scores->dim_size(1));
146 num_classes = classes->dim_size(1);
147 }
else if (classes) {
148 num_classes = classes->dim_size(1);
150 num_classes = scores->dim_size(1);
154 for (
int i = 0; i < num_examples; ++i) {
155 serving::Classifications* classifications = result->add_classifications();
156 for (
int c = 0; c < num_classes; ++c) {
157 serving::Class* cl = classifications->add_classes();
159 const tstring& class_tstr = (classes->matrix<tstring>())(i, c);
160 cl->set_label(class_tstr.data(), class_tstr.size());
163 cl->set_score((scores->matrix<
float>())(i, c));
167 return absl::OkStatus();
170 Status RunClassify(
const tfrt::SavedModel::RunOptions& run_options,
171 const absl::optional<int64_t>& servable_version,
172 tfrt::SavedModel* saved_model,
173 const ClassificationRequest& request,
174 ClassificationResponse* response) {
175 const string function_name = request.model_spec().signature_name().empty()
176 ? kDefaultServingSignatureDefKey
177 : request.model_spec().signature_name();
179 const auto function_metadata =
180 saved_model->GetFunctionMetadata(function_name);
181 if (!function_metadata.has_value()) {
182 return errors::FailedPrecondition(
183 strings::StrCat(
"Function \"", function_name,
"\" not found."));
186 MakeModelSpec(request.model_spec().name(), function_name, servable_version,
187 response->mutable_model_spec());
190 TF_RETURN_IF_ERROR(PreProcessClassification(function_metadata.value()));
193 InputToSerializedExampleTensor(request.input(), &input_tensor));
194 std::vector<Tensor> input_tensors;
195 int num_examples = input_tensor.dim_size(0);
196 input_tensors.emplace_back(std::move(input_tensor));
199 std::vector<Tensor> output_tensors;
200 const uint64_t start_microseconds = EnvTime::NowMicros();
201 if (
const auto status = saved_model->Run(run_options, function_name,
202 input_tensors, &output_tensors);
204 if (IsTfrtErrorLoggingEnabled()) {
205 tsl::error_logging::Log(
"TFRT",
"SavedModelRun", status.message())
210 const uint64_t end_microseconds = EnvTime::NowMicros();
211 RecordRuntimeLatency(request.model_spec().name(),
"Classify",
213 end_microseconds - start_microseconds);
216 return PostProcessClassificationResult(
217 num_examples, function_metadata->GetOutputNames(), output_tensors,
218 response->mutable_result());