TensorFlow Serving C++ API Documentation
predict_util.h
1 /* Copyright 2018 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_PREDICT_UTIL_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_PREDICT_UTIL_H_
18 
19 #include "absl/types/optional.h"
20 #include "tensorflow/core/lib/core/status.h"
21 #include "tensorflow/core/platform/threadpool_options.h"
22 #include "tensorflow/core/protobuf/config.pb.h"
23 #include "tensorflow/core/protobuf/meta_graph.pb.h"
24 #include "tensorflow/core/public/session.h"
25 #include "tensorflow_serving/apis/predict.pb.h"
26 #include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
27 
28 namespace tensorflow {
29 namespace serving {
30 
31 namespace internal {
32 
33 // Similar to RunPredict below, but allows specification of a serialization
34 // option for the TensorProtos in the response.
35 Status RunPredict(
36  const RunOptions& run_options, const MetaGraphDef& meta_graph_def,
37  const absl::optional<int64_t>& servable_version,
38  const PredictResponseTensorSerializationOption tensor_serialization_option,
39  Session* session, const PredictRequest& request, PredictResponse* response,
40  const thread::ThreadPoolOptions& thread_pool_options =
41  thread::ThreadPoolOptions());
42 
43 // Validate a SignatureDef to make sure it's compatible with prediction, and
44 // if so, populate the input and output tensor names.
45 Status PreProcessPrediction(const SignatureDef& signature,
46  const PredictRequest& request,
47  std::vector<std::pair<string, Tensor>>* inputs,
48  std::vector<string>* output_tensor_names,
49  std::vector<string>* output_tensor_aliases);
50 
51 // Validate results and populate a PredictResponse.
52 // Tensors are serialized as specified.
53 Status PostProcessPredictionResult(
54  const std::vector<string>& output_tensor_aliases,
55  const std::vector<Tensor>& output_tensors,
56  const internal::PredictResponseTensorSerializationOption option,
57  PredictResponse* response);
58 
59 } // namespace internal
60 
61 // Implementation of Predict using the SavedModel SignatureDef format.
62 //
63 // IMPLEMENTATION NOTES: Calls the internal::RunPredict function above by
64 // specifying serialization option as kAsProtoField for backward compatibility.
65 Status RunPredict(const RunOptions& run_options,
66  const MetaGraphDef& meta_graph_def,
67  const absl::optional<int64_t>& servable_version,
68  Session* session, const PredictRequest& request,
69  PredictResponse* response,
70  const thread::ThreadPoolOptions& thread_pool_options =
71  thread::ThreadPoolOptions());
72 
73 } // namespace serving
74 } // namespace tensorflow
75 
76 #endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_PREDICT_UTIL_H_