TensorFlow Serving C++ API Documentation
classifier.h
1 /* Copyright 2017 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // TensorFlow implementation of the ClassifierInterface.
17 
18 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_CLASSIFIER_H_
19 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_CLASSIFIER_H_
20 
21 #include <memory>
22 
23 #include "absl/types/optional.h"
24 #include "tensorflow/cc/saved_model/loader.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/threadpool_options.h"
27 #include "tensorflow_serving/apis/classifier.h"
28 
29 namespace tensorflow {
30 namespace serving {
31 
32 // Create a new ClassifierInterface backed by a TensorFlow SavedModel.
33 // Requires that the default SignatureDef be compatible with classification.
34 Status CreateClassifierFromSavedModelBundle(
35  const RunOptions& run_options, std::unique_ptr<SavedModelBundle> bundle,
36  std::unique_ptr<ClassifierInterface>* service);
37 
38 // Create a new ClassifierInterface backed by a TensorFlow Session using the
39 // specified SignatureDef. Does not take ownership of the Session.
40 // Useful in contexts where we need to avoid copying, e.g. if created per
41 // request. The caller must ensure that the session and signature live at least
42 // as long as the service.
43 Status CreateFlyweightTensorFlowClassifier(
44  const RunOptions& run_options, Session* session,
45  const SignatureDef* signature,
46  std::unique_ptr<ClassifierInterface>* service);
47 
48 // Similar to the above function, but with an additional 'thread_pool_options'.
49 Status CreateFlyweightTensorFlowClassifier(
50  const RunOptions& run_options, Session* session,
51  const SignatureDef* signature,
52  const thread::ThreadPoolOptions& thread_pool_options,
53  std::unique_ptr<ClassifierInterface>* service);
54 
55 // Get a classification signature from the meta_graph_def that's either:
56 // 1) The signature that model_spec explicitly specifies to use.
57 // 2) The default serving signature.
58 // If neither exist, or there were other issues, an error status is returned.
59 Status GetClassificationSignatureDef(const ModelSpec& model_spec,
60  const MetaGraphDef& meta_graph_def,
61  SignatureDef* signature);
62 
63 // Validate a SignatureDef to make sure it's compatible with classification.
64 // Populate the input and output tensor names, if the args are not nullptr.
65 //
66 // NOTE: output_tensor_names may already have elements in it (e.g. when building
67 // a full list of outputs from multiple signatures), and this function will just
68 // append to the vector.
69 Status PreProcessClassification(const SignatureDef& signature,
70  string* input_tensor_name,
71  std::vector<string>* output_tensor_names);
72 
73 // Validate all results and populate a ClassificationResult.
74 Status PostProcessClassificationResult(
75  const SignatureDef& signature, int num_examples,
76  const std::vector<string>& output_tensor_names,
77  const std::vector<Tensor>& output_tensors, ClassificationResult* result);
78 
79 // Creates SavedModelTensorflowClassifier and runs Classification on it.
80 Status RunClassify(const RunOptions& run_options,
81  const MetaGraphDef& meta_graph_def,
82  const absl::optional<int64_t>& servable_version,
83  Session* session, const ClassificationRequest& request,
84  ClassificationResponse* response,
85  const thread::ThreadPoolOptions& thread_pool_options =
86  thread::ThreadPoolOptions());
87 
88 } // namespace serving
89 } // namespace tensorflow
90 
91 #endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_CLASSIFIER_H_