16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SAVED_MODEL_FACTORY_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SAVED_MODEL_FACTORY_H_
23 #include "absl/status/status.h"
24 #include "absl/synchronization/mutex.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/tfrt/runtime/runtime.h"
30 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
31 #include "tensorflow_serving/batching/tfrt_saved_model_with_batching.h"
32 #include "tensorflow_serving/core/loader.h"
33 #include "tensorflow_serving/resources/resources.pb.h"
34 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
35 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
36 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
37 #include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"
39 namespace tensorflow {
51 using Batcher = SharedBatchScheduler<SavedModelBatchingTask>;
54 std::shared_ptr<Batcher> batch_scheduler,
55 std::unique_ptr<ThreadPoolFactory> thread_pool_factory)
57 std::move(thread_pool_factory),
62 const TfrtSavedModelConfig& config,
63 std::shared_ptr<Batcher> batch_scheduler,
64 std::unique_ptr<ThreadPoolFactory> thread_pool_factory,
74 static absl::Status
Create(
const TfrtSavedModelConfig& config,
75 std::unique_ptr<TfrtSavedModelFactory>* factory);
85 std::unique_ptr<Servable>* servable);
87 ABSL_DEPRECATED(
"Use the overload that creates Servable instead")
89 const
Loader::Metadata& metadata, const
string& path,
90 std::unique_ptr<tfrt_stub::SavedModel>* saved_model);
99 ResourceAllocation* estimate) const;
101 const TfrtSavedModelConfig& config()
const {
return config_; }
102 TfrtSavedModelConfig& mutable_config() {
return config_; }
103 absl::string_view GetServingResourceType()
const;
109 virtual absl::StatusOr<std::unique_ptr<Servable>> OverrideServable(
110 const Loader::Metadata& metadata,
const std::string& path) {
116 virtual absl::Status RegisterCustomBackend(
117 tfrt_stub::GraphExecutionOptions& options) {
118 return absl::OkStatus();
121 virtual absl::Status Freeze(tfrt_stub::SavedModel& saved_model) {
122 return absl::OkStatus();
125 TfrtSavedModelConfig config_;
129 std::shared_ptr<Batcher> batch_scheduler_;
133 std::unique_ptr<ThreadPoolFactory> thread_pool_factory_;
135 std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
136 recorder_creator_ = [](TfrtSavedModelServable&) {
return nullptr; };
138 TF_DISALLOW_COPY_AND_ASSIGN(TfrtSavedModelFactory);
148 std::function<absl::StatusOr<std::unique_ptr<TfrtSavedModelFactory>>(
149 const TfrtSavedModelConfig& config)>;
153 void Register(CreateFn fn) {
154 absl::MutexLock lock(&mu_);
155 if (factory_create_fn_) {
156 LOG(INFO) <<
"Overriding TfrtSavedModelFactory's create function.";
158 factory_create_fn_ = std::move(fn);
161 CreateFn Get()
const {
162 absl::MutexLock lock(&mu_);
163 return factory_create_fn_;
167 mutable absl::Mutex mu_;
168 CreateFn factory_create_fn_ ABSL_GUARDED_BY(mu_);
173 absl::StatusOr<std::shared_ptr<TfrtSavedModelFactory::Batcher>>
174 CreateBatchSchedulerFromConfig(
const TfrtSavedModelConfig& config);
178 absl::StatusOr<std::unique_ptr<ThreadPoolFactory>>
179 CreateThreadPoolFactoryFromConfig(
const TfrtSavedModelConfig& config);
// NOTE(review): the three lines below look like signature-index residue rather
// than real declarations. They have no terminating semicolons, and the
// `const`, `static` and `virtual` qualifiers are only valid inside a class
// body. They appear to duplicate TfrtSavedModelFactory members
// (EstimateResourceRequirement, Create, CreateTfrtSavedModelWithMetadata).
// Confirm and remove.
absl::Status EstimateResourceRequirement(const string &path, ResourceAllocation *estimate) const
static absl::Status Create(const TfrtSavedModelConfig &config, std::unique_ptr< TfrtSavedModelFactory > *factory)
virtual absl::Status CreateTfrtSavedModelWithMetadata(const Loader::Metadata &metadata, const string &path, std::unique_ptr< Servable > *servable)