TensorFlow Serving C++ API Documentation
serving_session.h
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SERVING_SESSION_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SERVING_SESSION_H_
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/threadpool_options.h"
26 #include "tensorflow/core/public/session.h"
27 
28 namespace tensorflow {
29 namespace serving {
30 
// Session subclass that declares Create(), Extend() and Close() as `final`, so
// classes derived from ServingSession cannot override them. Only the
// declarations appear in this header; per the comment below, these methods
// return errors. Run() is not declared here and is left to subclasses.
34 class ServingSession : public Session {
35  public:
36  ServingSession() = default;
37  ~ServingSession() override = default;
38 
39  // Methods that return errors (declared final; defined outside this header).
40  Status Create(const GraphDef& graph) final;
41  Status Extend(const GraphDef& graph) final;
42  Status Close() final;
43 
44  // Run() is intentionally absent here: subclasses just implement Run().
45 };
46 
50  public:
// Constructs the wrapper around `wrapped`, taking ownership of it: the
// unique_ptr is moved into wrapped_. No null check is performed here, while
// the Run() and ListDevices() methods dereference wrapped_ unconditionally.
// Emits a VLOG(2) message once the wrapper has been created.
51  explicit ServingSessionWrapper(std::unique_ptr<Session> wrapped)
52  : wrapped_(std::move(wrapped)) {
53  VLOG(2) << "Created the ServingSessionWrapper around the Session.";
54  }
55 
// Defaulted destructor; the owned Session held in the wrapped_ unique_ptr is
// released together with the wrapper.
56  ~ServingSessionWrapper() override = default;
57 
// Run() overload without RunOptions. Forwards all four arguments unchanged to
// the wrapped Session's matching Run() overload and returns its Status.
58  Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
59  const std::vector<string>& output_tensor_names,
60  const std::vector<string>& target_node_names,
61  std::vector<Tensor>* outputs) override {
62  return wrapped_->Run(inputs, output_tensor_names, target_node_names,
63  outputs);
64  }
65 
// Run() overload taking RunOptions and a RunMetadata output pointer. Forwards
// every argument unchanged to the wrapped Session's matching Run() overload
// and returns its Status.
66  Status Run(const RunOptions& run_options,
67  const std::vector<std::pair<string, Tensor>>& inputs,
68  const std::vector<string>& output_tensor_names,
69  const std::vector<string>& target_node_names,
70  std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
71  return wrapped_->Run(run_options, inputs, output_tensor_names,
72  target_node_names, outputs, run_metadata);
73  }
74 
// Run() overload that additionally takes thread::ThreadPoolOptions. Forwards
// every argument, including thread_pool_options, unchanged to the wrapped
// Session's matching Run() overload and returns its Status.
75  Status Run(const RunOptions& run_options,
76  const std::vector<std::pair<string, Tensor>>& inputs,
77  const std::vector<string>& output_tensor_names,
78  const std::vector<string>& target_node_names,
79  std::vector<Tensor>* outputs, RunMetadata* run_metadata,
80  const thread::ThreadPoolOptions& thread_pool_options) override {
81  return wrapped_->Run(run_options, inputs, output_tensor_names,
82  target_node_names, outputs, run_metadata,
83  thread_pool_options);
84  }
85 
// Forwards to the wrapped Session's ListDevices(), passing `response` through
// unchanged, and returns its Status.
86  Status ListDevices(std::vector<DeviceAttributes>* response) override {
87  return wrapped_->ListDevices(response);
88  }
89 
90  private:
// The wrapped Session, owned by this wrapper; set once in the constructor and
// used as the target of every forwarded Run()/ListDevices() call.
91  std::unique_ptr<Session> wrapped_;
92 
// NOTE(review): macro is defined outside this file; by its name it presumably
// disables copy construction and copy assignment for this class -- confirm.
93  TF_DISALLOW_COPY_AND_ASSIGN(ServingSessionWrapper);
94 };
95 
96 // Subclass of ServingSessionWrapper which reroutes Run() calls with
97 // thread_pool_options to Run() without those options. This is to provide
98 // support for RemoteSession::Run, which does not implement the overloaded Run()
99 // method with thread pool options.
101  public:
103  std::unique_ptr<Session> wrapped)
104  : ServingSessionWrapper(std::move(wrapped)) {
105  VLOG(2) << "Created the SessionWrapperIgnoreThreadPoolOptions around the "
106  "Session.";
107  }
108 
// Run() overload taking thread::ThreadPoolOptions. The thread_pool_options
// argument is deliberately ignored: the call is rerouted to
// ServingSessionWrapper::Run() without thread pool options, with all other
// arguments passed through unchanged, and that call's Status is returned.
109  Status Run(const RunOptions& run_options,
110  const std::vector<std::pair<string, Tensor>>& inputs,
111  const std::vector<string>& output_tensor_names,
112  const std::vector<string>& target_node_names,
113  std::vector<Tensor>* outputs, RunMetadata* run_metadata,
114  const thread::ThreadPoolOptions& thread_pool_options) override {
115  return ServingSessionWrapper::Run(run_options, inputs, output_tensor_names,
116  target_node_names, outputs, run_metadata);
117  }
118 
119  private:
// NOTE(review): macro is defined outside this file; by its name it presumably
// disables copy construction and copy assignment for this class -- confirm.
120  TF_DISALLOW_COPY_AND_ASSIGN(SessionWrapperIgnoreThreadPoolOptions);
121 };
122 
123 } // namespace serving
124 } // namespace tensorflow
125 
126 #endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SERVING_SESSION_H_