TensorFlow Serving C++ API Documentation
regressor.h
1 /* Copyright 2017 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // TensorFlow implementation of the RegressorInterface.
17 
18 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_REGRESSOR_H_
19 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_REGRESSOR_H_
20 
21 #include <memory>
22 
23 #include "absl/types/optional.h"
24 #include "tensorflow/cc/saved_model/loader.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/threadpool_options.h"
27 #include "tensorflow_serving/apis/regressor.h"
28 
29 namespace tensorflow {
30 namespace serving {
31 
32 // Creates a new RegressorInterface backed by a TensorFlow SavedModel and hands it back through the 'service' out-pointer. 'bundle' is taken by value as a std::unique_ptr, so the caller gives up ownership of it.
33 // Requires that the default SignatureDef be compatible with Regression. Returns a Status (presumably non-OK when that requirement is not met; confirm in the implementation).
34 Status CreateRegressorFromSavedModelBundle(
35  const RunOptions& run_options, std::unique_ptr<SavedModelBundle> bundle,
36  std::unique_ptr<RegressorInterface>* service);
37 
38 // Creates a new RegressorInterface backed by a TensorFlow Session using the
39 // specified SignatureDef, and hands it back through the 'service' out-pointer.
40 // Does not take ownership of the Session. Useful in contexts where we need to
41 // avoid copying, e.g. if created per request. The caller must ensure that the
42 // session and signature live at least as long as the service.
43 Status CreateFlyweightTensorFlowRegressor(
44  const RunOptions& run_options, Session* session,
45  const SignatureDef* signature,
46  std::unique_ptr<RegressorInterface>* service);
47 
48 // Same as the CreateFlyweightTensorFlowRegressor overload above, but additionally takes 'thread_pool_options' (a thread::ThreadPoolOptions). NOTE(review): how these options are applied is not visible in this header; see the implementation.
49 Status CreateFlyweightTensorFlowRegressor(
50  const RunOptions& run_options, Session* session,
51  const SignatureDef* signature,
52  const thread::ThreadPoolOptions& thread_pool_options,
53  std::unique_ptr<RegressorInterface>* service);
54 
55 // Gets a regression signature from meta_graph_def, stored into 'signature', that's either:
56 // 1) The signature that model_spec explicitly specifies to use.
57 // 2) The default serving signature.
58 // If neither exists, or there were other issues, an error status is returned.
59 Status GetRegressionSignatureDef(const ModelSpec& model_spec,
60  const MetaGraphDef& meta_graph_def,
61  SignatureDef* signature);
62 
63 // Validates a SignatureDef to make sure it's compatible with Regression.
64 // Populates 'input_tensor_name' and 'output_tensor_names' with the input and output tensor names; an argument that is nullptr is not populated.
65 //
66 // NOTE: output_tensor_names may already have elements in it (e.g. when building
67 // a full list of outputs from multiple signatures), and this function will just
68 // append to the vector.
69 Status PreProcessRegression(const SignatureDef& signature,
70  string* input_tensor_name,
71  std::vector<string>* output_tensor_names);
72 
73 // Validates all results and populates a RegressionResult ('result'). NOTE(review): presumably 'output_tensors' are checked against 'signature', 'output_tensor_names' and 'num_examples'; confirm in the implementation.
74 Status PostProcessRegressionResult(
75  const SignatureDef& signature, int num_examples,
76  const std::vector<string>& output_tensor_names,
77  const std::vector<Tensor>& output_tensors, RegressionResult* result);
78 
79 // Creates a SavedModelTensorflowRegressor and runs Regression on it. 'thread_pool_options' may be omitted and then defaults to a default-constructed thread::ThreadPoolOptions. NOTE(review): presumably 'request' is the input and 'response' receives the result; confirm in the implementation.
80 Status RunRegress(const RunOptions& run_options,
81  const MetaGraphDef& meta_graph_def,
82  const absl::optional<int64_t>& servable_version,
83  Session* session, const RegressionRequest& request,
84  RegressionResponse* response,
85  const thread::ThreadPoolOptions& thread_pool_options =
86  thread::ThreadPoolOptions());
87 
88 } // namespace serving
89 } // namespace tensorflow
90 
91 #endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_REGRESSOR_H_