TensorFlow Serving C++ API Documentation
tfrt_saved_model_factory.h
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SAVED_MODEL_FACTORY_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SAVED_MODEL_FACTORY_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 
23 #include "absl/status/status.h"
24 #include "absl/synchronization/mutex.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/tfrt/runtime/runtime.h"
30 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
31 #include "tensorflow_serving/batching/tfrt_saved_model_with_batching.h"
32 #include "tensorflow_serving/core/loader.h"
33 #include "tensorflow_serving/resources/resources.pb.h"
34 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
35 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
36 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
37 #include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"
38 
39 namespace tensorflow {
40 namespace serving {
41 
50  public:
// Batch scheduler type used by this factory; the scheduled task type is
// SavedModelBatchingTask.
51  using Batcher = SharedBatchScheduler<SavedModelBatchingTask>;
52 
// Convenience constructor. Forwards `config`, `batch_scheduler` and
// `thread_pool_factory` to the four-argument constructor and supplies a
// recorder creator that always returns nullptr.
53  TfrtSavedModelFactory(const TfrtSavedModelConfig& config,
54  std::shared_ptr<Batcher> batch_scheduler,
55  std::unique_ptr<ThreadPoolFactory> thread_pool_factory)
56  : TfrtSavedModelFactory(config, std::move(batch_scheduler),
57  std::move(thread_pool_factory),
58  [](TfrtSavedModelServable&) { return nullptr; }) {
59  }
60 
62  const TfrtSavedModelConfig& config,
63  std::shared_ptr<Batcher> batch_scheduler,
64  std::unique_ptr<ThreadPoolFactory> thread_pool_factory,
65  std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
66  recorder_creator);
67 
// Virtual destructor: this class declares virtual methods that subclasses may
// override. Defined outside this header.
68  virtual ~TfrtSavedModelFactory();
69 
// Static creation entry point. Takes the configuration by const reference,
// has a `factory` out-parameter, and reports success or failure through the
// returned absl::Status.
// NOTE(review): presumably `*factory` holds the newly created factory only
// when the returned status is OK -- the definition is not in this header;
// confirm against the implementation.
74  static absl::Status Create(const TfrtSavedModelConfig& config,
75  std::unique_ptr<TfrtSavedModelFactory>* factory);
76 
// Virtual, so subclasses may replace it. Takes the loader metadata (per the
// Loader documentation, the metadata consists of the ServableId), a `path`,
// and a `servable` out-parameter; the outcome is reported through the
// returned absl::Status.
// NOTE(review): presumably loads the SavedModel found at `path` and stores
// the resulting servable in `*servable` on success -- the definition is not
// in this header; confirm against the implementation.
83  virtual absl::Status CreateTfrtSavedModelWithMetadata(
84  const Loader::Metadata& metadata, const string& path,
85  std::unique_ptr<Servable>* servable);
86 
87  ABSL_DEPRECATED("Use the overload that creates Servable instead")
89  const Loader::Metadata& metadata, const string& path,
90  std::unique_ptr<tfrt_stub::SavedModel>* saved_model);
91 
// Const member taking a `path` and a `ResourceAllocation*` out-parameter; the
// outcome is reported through the returned absl::Status.
// NOTE(review): presumably fills `*estimate` with the resources needed for
// the SavedModel at `path` -- the definition is not in this header; confirm.
98  absl::Status EstimateResourceRequirement(const string& path,
99  ResourceAllocation* estimate) const;
100 
// Read-only access to the stored configuration.
101  const TfrtSavedModelConfig& config() const { return config_; }
// Mutable access to the stored configuration; callers can modify it in place.
102  TfrtSavedModelConfig& mutable_config() { return config_; }
// Declared here and defined elsewhere.
// NOTE(review): presumably returns the name of the resource kind used for
// serving resource accounting -- confirm against the definition.
103  absl::string_view GetServingResourceType() const;
104 
105  private:
106  // The subclass can override this method to return a custom servable
107  // instead of creating one using CreateTfrtSavedModelWithMetadata(). If it
108  // returns nullptr, CreateTfrtSavedModelWithMetadata() will be used normally.
// The default implementation ignores both arguments and returns nullptr, so
// no override takes place unless a subclass provides one.
109  virtual absl::StatusOr<std::unique_ptr<Servable>> OverrideServable(
110  const Loader::Metadata& metadata, const std::string& path) {
111  return nullptr;
112  }
113 
114  // The subclass can override this method to register a custom backend with
115  // the TFRT SavedModel.
// `options` is taken by non-const reference, so an override can modify it.
// The default implementation leaves `options` untouched and returns OkStatus.
116  virtual absl::Status RegisterCustomBackend(
117  tfrt_stub::GraphExecutionOptions& options) {
118  return absl::OkStatus();
119  }
120 
// Virtual hook that takes a tfrt_stub::SavedModel by non-const reference.
// The default implementation does nothing and returns OkStatus.
// NOTE(review): the call site is not visible in this header; presumably it is
// invoked on a saved model after it has been created -- confirm.
121  virtual absl::Status Freeze(tfrt_stub::SavedModel& saved_model) {
122  return absl::OkStatus();
123  }
124 
// Configuration of this factory; exposed through config() and
// mutable_config().
125  TfrtSavedModelConfig config_;
126 
127  // A shared batch scheduler. One queue is used for each saved model this
128  // factory emits. If batching is not configured, this remains null.
// Held through a shared_ptr, so other holders may share ownership of it.
129  std::shared_ptr<Batcher> batch_scheduler_;
130 
131  // `thread_pool_factory_` is used to create inter-op ThreadPools. It may be
132  // nullptr, in which case the default TensorFlow thread pools should be used.
133  std::unique_ptr<ThreadPoolFactory> thread_pool_factory_;
134 
// Callable that produces a RequestRecorder for a given TfrtSavedModelServable.
// Its default member initializer is a lambda that always returns nullptr.
135  std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
136  recorder_creator_ = [](TfrtSavedModelServable&) { return nullptr; };
137 
// Makes TfrtSavedModelFactory non-copyable and non-assignable.
138  TF_DISALLOW_COPY_AND_ASSIGN(TfrtSavedModelFactory);
139 };
140 
141 // Registry holding the function used to create a TfrtSavedModelFactory. By
142 // default the CreateFn creates an instance of TfrtSavedModelFactory itself.
143 // Custom implementations can register their own CreateFn through this
144 // registry so that a subclass of TfrtSavedModelFactory is created instead.
146  public:
// Signature of a factory-creating function: it receives the config and
// returns either a TfrtSavedModelFactory (by unique_ptr) or an error status.
147  using CreateFn =
148  std::function<absl::StatusOr<std::unique_ptr<TfrtSavedModelFactory>>(
149  const TfrtSavedModelConfig& config)>;
150 
152 
// Installs `fn` as the creation function while holding `mu_`. If a creation
// function is already set, an INFO message is logged before it is replaced.
153  void Register(CreateFn fn) {
154  absl::MutexLock lock(&mu_);
155  if (factory_create_fn_) {
156  LOG(INFO) << "Overriding TfrtSavedModelFactory's create function.";
157  }
158  factory_create_fn_ = std::move(fn);
159  }
160 
// Returns a copy of the currently registered creation function; the read is
// performed while holding `mu_`.
161  CreateFn Get() const {
162  absl::MutexLock lock(&mu_);
163  return factory_create_fn_;
164  }
165 
166  private:
// Guards factory_create_fn_. Declared mutable so the const Get() can lock it.
167  mutable absl::Mutex mu_;
// The currently registered creation function; accessed only under mu_.
168  CreateFn factory_create_fn_ ABSL_GUARDED_BY(mu_);
169 };
170 
171 // Creates a batch scheduler based on `config`. The result can be a nullptr if
172 // `config` does not specify batch parameters.
// Failures are reported through the returned StatusOr.
173 absl::StatusOr<std::shared_ptr<TfrtSavedModelFactory::Batcher>>
174 CreateBatchSchedulerFromConfig(const TfrtSavedModelConfig& config);
175 
176 // Creates a thread pool factory based on `config`. The result can be a nullptr
177 // if `config` does not specify thread pool factory config.
// Failures are reported through the returned StatusOr.
178 absl::StatusOr<std::unique_ptr<ThreadPoolFactory>>
179 CreateThreadPoolFactoryFromConfig(const TfrtSavedModelConfig& config);
180 
// Returns a reference to the registry of TfrtSavedModelFactory creation
// functions.
// NOTE(review): the name suggests a single process-wide instance; the
// definition is not in this header -- confirm.
181 TfrtSavedModelFactoryRegistry& GetGlobalTfrtSavedModelFactoryRegistry();
182 
183 } // namespace serving
184 } // namespace tensorflow
185 
186 #endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SAVED_MODEL_FACTORY_H_
absl::Status EstimateResourceRequirement(const string &path, ResourceAllocation *estimate) const
static absl::Status Create(const TfrtSavedModelConfig &config, std::unique_ptr< TfrtSavedModelFactory > *factory)
virtual absl::Status CreateTfrtSavedModelWithMetadata(const Loader::Metadata &metadata, const string &path, std::unique_ptr< Servable > *servable)
The metadata consists of the ServableId.
Definition: loader.h:94