16 #include "tensorflow_serving/servables/tensorflow/classifier.h"
27 #include "tensorflow/cc/saved_model/signature_constants.h"
28 #include "tensorflow/core/example/example.pb.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/notification.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/platform/threadpool.h"
34 #include "tensorflow/core/platform/threadpool_options.h"
35 #include "tensorflow/core/platform/tracing.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow_serving/apis/classification.pb.h"
38 #include "tensorflow_serving/apis/classifier.h"
39 #include "tensorflow_serving/apis/input.pb.h"
40 #include "tensorflow_serving/apis/model.pb.h"
41 #include "tensorflow_serving/servables/tensorflow/util.h"
43 namespace tensorflow {
48 class SavedModelTensorFlowClassifier :
public ClassifierInterface {
50 explicit SavedModelTensorFlowClassifier(
51 const RunOptions& run_options, Session* session,
52 const SignatureDef*
const signature,
53 const thread::ThreadPoolOptions& thread_pool_options =
54 thread::ThreadPoolOptions())
55 : run_options_(run_options),
57 signature_(signature),
58 thread_pool_options_(thread_pool_options) {}
60 ~SavedModelTensorFlowClassifier()
override =
default;
62 Status Classify(
const ClassificationRequest& request,
63 ClassificationResult* result)
override {
64 TRACELITERAL(
"TensorFlowClassifier::Classify");
66 string input_tensor_name;
67 std::vector<string> output_tensor_names;
68 TF_RETURN_IF_ERROR(PreProcessClassification(*signature_, &input_tensor_name,
69 &output_tensor_names));
71 std::vector<Tensor> outputs;
73 int64_t runtime_latency;
74 TF_RETURN_IF_ERROR(PerformOneShotTensorComputation(
75 run_options_, request.input(), input_tensor_name, output_tensor_names,
76 session_, &outputs, &num_examples, thread_pool_options_,
78 RecordRuntimeLatency(request.model_spec().name(),
"Classify",
79 "TF1", runtime_latency);
81 TRACELITERAL(
"ConvertToClassificationResult");
82 return PostProcessClassificationResult(
83 *signature_, num_examples, output_tensor_names, outputs, result);
87 const RunOptions run_options_;
88 Session*
const session_;
89 const SignatureDef*
const signature_;
90 const thread::ThreadPoolOptions thread_pool_options_;
92 TF_DISALLOW_COPY_AND_ASSIGN(SavedModelTensorFlowClassifier);
95 class SavedModelClassifier :
public ClassifierInterface {
97 SavedModelClassifier(
const RunOptions& run_options,
98 std::unique_ptr<SavedModelBundle> bundle)
99 : run_options_(run_options), bundle_(std::move(bundle)) {}
101 ~SavedModelClassifier()
override =
default;
103 Status Classify(
const ClassificationRequest& request,
104 ClassificationResult* result)
override {
109 SignatureDef signature;
110 TF_RETURN_IF_ERROR(GetClassificationSignatureDef(
111 request.model_spec(), bundle_->meta_graph_def, &signature));
112 SavedModelTensorFlowClassifier classifier(
113 run_options_, bundle_->session.get(), &signature);
114 return classifier.Classify(request, result);
118 const RunOptions run_options_;
119 std::unique_ptr<SavedModelBundle> bundle_;
121 TF_DISALLOW_COPY_AND_ASSIGN(SavedModelClassifier);
126 Status CreateClassifierFromSavedModelBundle(
127 const RunOptions& run_options, std::unique_ptr<SavedModelBundle> bundle,
128 std::unique_ptr<ClassifierInterface>* service) {
129 service->reset(
new SavedModelClassifier(run_options, std::move(bundle)));
130 return absl::OkStatus();
133 Status CreateFlyweightTensorFlowClassifier(
134 const RunOptions& run_options, Session* session,
135 const SignatureDef* signature,
136 std::unique_ptr<ClassifierInterface>* service) {
137 return CreateFlyweightTensorFlowClassifier(
138 run_options, session, signature, thread::ThreadPoolOptions(), service);
141 Status CreateFlyweightTensorFlowClassifier(
142 const RunOptions& run_options, Session* session,
143 const SignatureDef* signature,
144 const thread::ThreadPoolOptions& thread_pool_options,
145 std::unique_ptr<ClassifierInterface>* service) {
146 service->reset(
new SavedModelTensorFlowClassifier(
147 run_options, session, signature, thread_pool_options));
148 return absl::OkStatus();
151 Status GetClassificationSignatureDef(
const ModelSpec& model_spec,
152 const MetaGraphDef& meta_graph_def,
153 SignatureDef* signature) {
154 const string signature_name = model_spec.signature_name().empty()
155 ? kDefaultServingSignatureDefKey
156 : model_spec.signature_name();
157 auto iter = meta_graph_def.signature_def().find(signature_name);
158 if (iter == meta_graph_def.signature_def().end()) {
159 return errors::InvalidArgument(strings::StrCat(
160 "No signature was found with the name: ", signature_name));
162 if (GetSignatureMethodNameCheckFeature()) {
163 if (iter->second.method_name() != kClassifyMethodName) {
164 return errors::InvalidArgument(strings::StrCat(
165 "Expected classification signature method_name to be ",
166 kClassifyMethodName,
". Was: ", iter->second.method_name()));
170 PreProcessClassification(iter->second,
nullptr,
nullptr));
172 *signature = iter->second;
173 return absl::OkStatus();
176 Status PreProcessClassification(
const SignatureDef& signature,
177 string* input_tensor_name,
178 std::vector<string>* output_tensor_names) {
179 if (GetSignatureMethodNameCheckFeature() &&
180 signature.method_name() != kClassifyMethodName) {
181 return errors::InvalidArgument(strings::StrCat(
182 "Expected classification signature method_name to be ",
183 kClassifyMethodName,
". Was: ", signature.method_name()));
185 if (signature.inputs().size() != 1) {
186 return errors::InvalidArgument(
187 strings::StrCat(
"Expected one input Tensor."));
189 if (signature.outputs().size() != 1 && signature.outputs().size() != 2) {
190 return errors::InvalidArgument(
191 strings::StrCat(
"Expected one or two output Tensors, found ",
192 signature.outputs().size()));
195 auto input_iter = signature.inputs().find(kClassifyInputs);
196 if (input_iter == signature.inputs().end()) {
197 return errors::InvalidArgument(
198 "No classification inputs found in SignatureDef: ",
199 signature.DebugString());
201 if (input_tensor_name !=
nullptr) {
202 *input_tensor_name = input_iter->second.name();
205 auto classes_iter = signature.outputs().find(kClassifyOutputClasses);
206 auto scores_iter = signature.outputs().find(kClassifyOutputScores);
207 if (classes_iter == signature.outputs().end() &&
208 scores_iter == signature.outputs().end()) {
209 return errors::InvalidArgument(strings::StrCat(
210 "Expected classification signature outputs to contain at least one of ",
211 "\"", kClassifyOutputClasses,
"\" or \"", kClassifyOutputScores,
212 "\". Signature was: ", signature.DebugString()));
214 if (output_tensor_names !=
nullptr) {
215 if (classes_iter != signature.outputs().end()) {
216 output_tensor_names->push_back(classes_iter->second.name());
218 if (scores_iter != signature.outputs().end()) {
219 output_tensor_names->push_back(scores_iter->second.name());
222 return absl::OkStatus();
225 Status PostProcessClassificationResult(
226 const SignatureDef& signature,
int num_examples,
227 const std::vector<string>& output_tensor_names,
228 const std::vector<Tensor>& output_tensors, ClassificationResult* result) {
229 if (output_tensors.size() != output_tensor_names.size()) {
230 return errors::InvalidArgument(
231 strings::StrCat(
"Expected ", output_tensor_names.size(),
232 " output tensor(s). Got: ", output_tensors.size()));
235 auto classes_iter = signature.outputs().find(kClassifyOutputClasses);
236 string classes_tensor_name;
237 if (classes_iter != signature.outputs().end()) {
238 classes_tensor_name = classes_iter->second.name();
240 auto scores_iter = signature.outputs().find(kClassifyOutputScores);
241 string scores_tensor_name;
242 if (scores_iter != signature.outputs().end()) {
243 scores_tensor_name = scores_iter->second.name();
246 const Tensor* classes =
nullptr;
247 const Tensor* scores =
nullptr;
248 for (
int i = 0; i < output_tensors.size(); ++i) {
249 if (output_tensor_names[i] == classes_tensor_name) {
250 classes = &output_tensors[i];
251 }
else if (output_tensor_names[i] == scores_tensor_name) {
252 scores = &output_tensors[i];
258 if (classes->dims() != 2) {
259 return errors::InvalidArgument(
260 "Expected Tensor shape: [batch_size num_classes] but got ",
261 classes->shape().DebugString());
263 if (classes->dtype() != DT_STRING) {
264 return errors::InvalidArgument(
265 "Expected classes Tensor of DT_STRING. Got: ",
266 DataType_Name(classes->dtype()));
268 if (classes->dim_size(0) != num_examples) {
269 return errors::InvalidArgument(
"Expected classes output batch size of ",
271 ". Got: ", classes->dim_size(0));
276 if (scores->dims() != 2) {
277 return errors::InvalidArgument(
278 "Expected Tensor shape: [batch_size num_classes] but got ",
279 scores->shape().DebugString());
281 if (scores->dtype() != DT_FLOAT) {
282 return errors::InvalidArgument(
283 "Expected scores Tensor of DT_FLOAT. Got: ",
284 DataType_Name(scores->dtype()));
286 if (scores->dim_size(0) != num_examples) {
287 return errors::InvalidArgument(
"Expected scores output batch size of ",
289 ". Got: ", scores->dim_size(0));
295 if (classes && scores) {
297 if (classes->dim_size(1) != scores->dim_size(1)) {
298 return errors::InvalidArgument(
299 "Tensors class and score should match in dim_size(1). Got ",
300 classes->dim_size(1),
" vs. ", scores->dim_size(1));
302 num_classes = classes->dim_size(1);
303 }
else if (classes) {
304 num_classes = classes->dim_size(1);
306 num_classes = scores->dim_size(1);
310 for (
int i = 0; i < num_examples; ++i) {
311 serving::Classifications* classifications = result->add_classifications();
312 for (
int c = 0; c < num_classes; ++c) {
313 serving::Class* cl = classifications->add_classes();
315 const tstring& class_tstr = (classes->matrix<tstring>())(i, c);
316 cl->set_label(class_tstr.data(), class_tstr.size());
319 cl->set_score((scores->matrix<
float>())(i, c));
323 return absl::OkStatus();
326 Status RunClassify(
const RunOptions& run_options,
327 const MetaGraphDef& meta_graph_def,
328 const absl::optional<int64_t>& servable_version,
329 Session* session,
const ClassificationRequest& request,
330 ClassificationResponse* response,
331 const thread::ThreadPoolOptions& thread_pool_options) {
332 SignatureDef signature;
333 TF_RETURN_IF_ERROR(GetClassificationSignatureDef(request.model_spec(),
334 meta_graph_def, &signature));
336 std::unique_ptr<ClassifierInterface> classifier_interface;
337 TF_RETURN_IF_ERROR(CreateFlyweightTensorFlowClassifier(
338 run_options, session, &signature, thread_pool_options,
339 &classifier_interface));
341 MakeModelSpec(request.model_spec().name(),
342 request.model_spec().signature_name(), servable_version,
343 response->mutable_model_spec());
346 return classifier_interface->Classify(request, response->mutable_result());