16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SERVABLE_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SERVABLE_H_
24 #include "absl/base/thread_annotations.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/functional/any_invocable.h"
27 #include "absl/status/status.h"
28 #include "absl/status/statusor.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/synchronization/mutex.h"
31 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
32 #include "tensorflow_serving/apis/classification.pb.h"
33 #include "tensorflow_serving/apis/get_model_metadata.pb.h"
34 #include "tensorflow_serving/apis/inference.pb.h"
35 #include "tensorflow_serving/apis/predict.pb.h"
36 #include "tensorflow_serving/apis/regression.pb.h"
37 #include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
38 #include "tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
39 #include "tensorflow_serving/servables/tensorflow/servable.h"
40 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
41 #include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"
43 namespace tensorflow {
// NOTE(review): fragment of a TfrtSavedModelServable constructor — the
// opening of the signature (name/version parameters) is on lines not visible
// in this chunk; it appears to delegate to another constructor overload.
// TODO(review): confirm against the full header.
59 const TfrtSavedModelConfig& config,
60 const SavedModelConfig& model_config,
61 std::unique_ptr<tfrt_stub::SavedModel> saved_model,
// Delegated-constructor argument list: forwards all configuration and
// transfers ownership of the loaded SavedModel.
64 name, version, config, model_config, std::move(saved_model),
// NOTE(review): parameter list of another TfrtSavedModelServable constructor
// overload; the start and remainder of the declaration are on lines not
// visible in this chunk.
69 absl::string_view name, int64_t version,
70 const TfrtSavedModelConfig& config,
const SavedModelConfig& model_config,
// Ownership of the loaded tfrt_stub::SavedModel is transferred to the
// servable (stored in saved_model_ below).
71 std::unique_ptr<tfrt_stub::SavedModel> saved_model,
// Implements Servable::Classify: runs the classification API against this
// servable and writes the result into `response`. Returns a non-OK status on
// failure.
76 absl::Status Classify(
const RunOptions& run_options,
77 const ClassificationRequest& request,
78 ClassificationResponse* response)
override;
// Implements Servable::Regress: runs the regression API against this
// servable and writes the result into `response`. Returns a non-OK status on
// failure.
80 absl::Status Regress(
const RunOptions& run_options,
81 const RegressionRequest& request,
82 RegressionResponse* response)
override;
// Implements Servable::Predict: runs the predict API against this servable
// and writes the result into `response`. Tensor serialization in the response
// is controlled by predict_response_tensor_serialization_option_ below.
84 absl::Status Predict(
const RunOptions& run_options,
85 const PredictRequest& request,
86 PredictResponse* response)
override;
// Implements Servable::PredictStreamed: streaming variant of Predict. Each
// produced response (or error) is delivered through `response_callback`; on
// success returns a context object that manages the stream.
// NOTE(review): the parameter line carrying run_options/request is missing
// from this chunk — the declaration is incomplete as shown.
88 absl::StatusOr<std::unique_ptr<PredictStreamedContext>> PredictStreamed(
90 absl::AnyInvocable<
void(absl::StatusOr<PredictResponse>)>
91 response_callback)
override;
// Implements Servable::MultiInference: runs multiple inference tasks
// (classify/regress) from a single request and writes all results into
// `response`.
93 absl::Status MultiInference(
const RunOptions& run_options,
94 const MultiInferenceRequest& request,
95 MultiInferenceResponse* response)
override;
// Implements Servable::GetModelMetadata: fills `response` with metadata
// (e.g. signatures) for the loaded model. Takes no RunOptions — metadata
// retrieval does not run the model.
97 absl::Status GetModelMetadata(
const GetModelMetadataRequest& request,
98 GetModelMetadataResponse* response)
override;
// This servable supports paging: Suspend()/Resume() below are functional
// (always returns true).
100 bool SupportsPaging()
const override {
return true; }
// Suspends (pages out) the servable. Presumably invokes the suspend_fn_
// callback stored under paging_mu_ — TODO(review): confirm in the .cc.
102 absl::Status Suspend()
override;
// Resumes a previously suspended servable. Presumably invokes resume_fn_ —
// TODO(review): confirm in the .cc.
104 absl::Status Resume()
override;
// Returns a mutable reference to the underlying TFRT SavedModel.
// Dereferences saved_model_ unchecked — assumes it is non-null for the
// lifetime of the servable (TODO(review): confirm construction guarantees).
106 tfrt_stub::SavedModel& saved_model()
const {
return *saved_model_; }
// NOTE(review): body fragment of the resume-hook setter — acquires paging_mu_
// and stores the callback into resume_fn_. The enclosing function signature
// is on lines not visible in this chunk.
110 absl::MutexLock lock(&paging_mu_);
111 resume_fn_ = std::move(resume_fn);
// NOTE(review): body fragment of the suspend-hook setter — acquires
// paging_mu_ and stores the callback into suspend_fn_; symmetric with the
// resume hook fragment above.
116 absl::MutexLock lock(&paging_mu_);
117 suspend_fn_ = std::move(suspend_fn);
121 tfrt_stub::SavedModel::RunOptions GetTFRTSavedModelRunOptions(
// Creates a per-request recorder by invoking the injected recorder_creator_
// functor with this servable. (Closing brace of the body falls outside this
// chunk.)
124 std::unique_ptr<RequestRecorder> CreateRecorder() {
125 return recorder_creator_(*
this);
// Owned TFRT SavedModel that backs all inference calls on this servable.
128 std::unique_ptr<tfrt_stub::SavedModel> saved_model_;
// Copy of the configuration this servable was created with.
132 TfrtSavedModelConfig config_;
// Controls how output tensors are serialized into PredictResponse; defaults
// to embedding them as proto fields.
134 internal::PredictResponseTensorSerializationOption
135 predict_response_tensor_serialization_option_ =
136 internal::PredictResponseTensorSerializationOption::kAsProtoField;
// NOTE(review): the member declarations these ABSL_GUARDED_BY annotations
// attach to — presumably the suspend/resume callbacks stored by the setter
// fragments above — are on lines not visible in this chunk.
147 ABSL_GUARDED_BY(paging_mu_);
150 ABSL_GUARDED_BY(paging_mu_);
// Whether the servable is currently suspended (paged out); guarded by
// paging_mu_ and false on construction.
152 bool suspended_ ABSL_GUARDED_BY(paging_mu_) =
false;
// Serializes all suspend/resume state transitions and hook installation.
154 absl::Mutex paging_mu_;
// Factory: creates a TfrtSavedModelServable for the model at
// `saved_model_dir`, loaded with `options` and the given MetaGraphDef `tags`.
// Returns a non-OK status if loading fails. NOTE(review): only the
// declaration is visible here — confirm load semantics in the .cc.
158 absl::StatusOr<std::unique_ptr<TfrtSavedModelServable>>
159 CreateTfrtSavedModelServable(
160 const tensorflow::tfrt_stub::SavedModel::Options& options,
161 absl::string_view name, int64_t version, absl::string_view saved_model_dir,
162 absl::flat_hash_set<std::string> tags);