16 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
23 #include <unordered_set>
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/functional/any_invocable.h"
29 #include "absl/status/status.h"
30 #include "absl/status/statusor.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/string_view.h"
33 #include "absl/synchronization/mutex.h"
34 #include "absl/time/time.h"
35 #include "tensorflow/cc/saved_model/signature_constants.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/tensor.pb.h"
38 #include "tensorflow/core/platform/tracing.h"
39 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
40 #include "tsl/platform/errors.h"
41 #include "tsl/platform/statusor.h"
42 #include "tsl/platform/threadpool_options.h"
43 #include "tensorflow_serving/apis/classification.pb.h"
44 #include "tensorflow_serving/apis/get_model_metadata.pb.h"
45 #include "tensorflow_serving/apis/inference.pb.h"
46 #include "tensorflow_serving/apis/predict.pb.h"
47 #include "tensorflow_serving/apis/regression.pb.h"
48 #include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
49 #include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
50 #include "tensorflow_serving/servables/tensorflow/servable.h"
51 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
52 #include "tensorflow_serving/servables/tensorflow/tfrt_multi_inference.h"
53 #include "tensorflow_serving/servables/tensorflow/tfrt_predict_util.h"
54 #include "tensorflow_serving/servables/tensorflow/tfrt_regressor.h"
55 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
56 #include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"
58 namespace tensorflow {
// Constructs a servable that wraps a loaded TFRT SavedModel.
//
// The base Servable is initialized from `name`, `version` and
// `model_config.critical()`. `saved_model` and `recorder_creator` are moved
// into the corresponding members; `thread_pool_factory` is stored as a raw
// pointer.
61 TfrtSavedModelServable::TfrtSavedModelServable(
62 absl::string_view name, int64_t version,
const TfrtSavedModelConfig& config,
63 const SavedModelConfig& model_config,
64 std::unique_ptr<tfrt_stub::SavedModel> saved_model,
65 ThreadPoolFactory* thread_pool_factory,
66 std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
68 : Servable(name, version, model_config.critical()),
69 saved_model_(std::move(saved_model)),
71 thread_pool_factory_(thread_pool_factory),
72 recorder_creator_(std::move(recorder_creator)) {
// Translate the config's serialization option into the internal enum here,
// in the constructor, and keep the result in a member.
73 switch (config_.predict_response_tensor_serialization_option()) {
// AS_PROTO_FIELD maps to kAsProtoField.
74 case TfrtSavedModelConfig::AS_PROTO_FIELD: {
75 predict_response_tensor_serialization_option_ =
76 internal::PredictResponseTensorSerializationOption::kAsProtoField;
// AS_PROTO_CONTENT maps to kAsProtoContent.
79 case TfrtSavedModelConfig::AS_PROTO_CONTENT: {
80 predict_response_tensor_serialization_option_ =
81 internal::PredictResponseTensorSerializationOption::kAsProtoContent;
// NOTE(review): this last assignment selects kAsProtoField again; presumably
// it is the fallback for any other option value -- confirm.
85 predict_response_tensor_serialization_option_ =
86 internal::PredictResponseTensorSerializationOption::kAsProtoField;
// Builds a tfrt_stub::SavedModel::RunOptions from the serving-level
// `run_options` and this servable's config_.
92 tfrt_stub::SavedModel::RunOptions
93 TfrtSavedModelServable::GetTFRTSavedModelRunOptions(
94 const Servable::RunOptions& run_options)
const {
95 tfrt_stub::SavedModel::RunOptions options;
// Only forward the deadline when the caller supplied a finite one; the
// absl::Time value is converted with absl::ToChronoTime.
96 if (run_options.deadline != absl::InfiniteFuture()) {
97 options.deadline = absl::ToChronoTime(run_options.deadline);
// Both input-spec validation flags are copied from config_.
99 options.validate_input_specs = config_.validate_input_specs();
100 options.validate_input_specs_dry_run = config_.validate_input_specs_dry_run();
// Handles a classification request by delegating to RunClassify.
//
// `run_options` is converted with GetTFRTSavedModelRunOptions(); the result of
// RunClassify is returned unchanged and `response` is passed through to it.
104 absl::Status TfrtSavedModelServable::Classify(
105 const RunOptions& run_options,
const ClassificationRequest& request,
106 ClassificationResponse* response) {
107 TRACELITERAL(
"TfrtSavedModelServable::Classify");
// The recorder is a local that lives until this function returns.
// NOTE(review): presumably it records per-request state for the duration of
// the call -- confirm against RequestRecorder / CreateRecorder.
108 auto recorder = CreateRecorder();
109 return RunClassify(GetTFRTSavedModelRunOptions(run_options), version(),
110 saved_model_.get(), request, response);
// Handles a regression request by delegating to RunRegress.
//
// `run_options` is converted with GetTFRTSavedModelRunOptions(); the result of
// RunRegress is returned unchanged and `response` is passed through to it.
113 absl::Status TfrtSavedModelServable::Regress(
const RunOptions& run_options,
114 const RegressionRequest& request,
115 RegressionResponse* response) {
116 TRACELITERAL(
"TfrtSavedModelServable::Regress");
// The recorder is a local that lives until this function returns.
// NOTE(review): presumably it records per-request state for the duration of
// the call -- confirm against RequestRecorder / CreateRecorder.
117 auto recorder = CreateRecorder();
118 return RunRegress(GetTFRTSavedModelRunOptions(run_options), version(),
119 saved_model_.get(), request, response);
// Handles a predict request by delegating to internal::RunPredict.
//
// The call receives the converted run options, this servable's version, the
// serialization option chosen in the constructor, and the wrapped saved model.
122 absl::Status TfrtSavedModelServable::Predict(
const RunOptions& run_options,
123 const PredictRequest& request,
124 PredictResponse* response) {
125 TRACELITERAL(
"TfrtSavedModelServable::Predict");
// The recorder is a local that lives until this function returns.
126 auto recorder = CreateRecorder();
127 return internal::RunPredict(
128 GetTFRTSavedModelRunOptions(run_options), version(),
129 predict_response_tensor_serialization_option_, saved_model_.get(),
// Without a thread pool factory, pass default-constructed ThreadPoolOptions;
// otherwise use the options obtained from the factory's thread pools.
131 thread_pool_factory_ ==
nullptr
132 ? tsl::thread::ThreadPoolOptions()
133 : thread_pool_factory_->GetThreadPools().get());
// Creates a streamed-predict context for a single request.
//
// The returned SingleRequestPredictStreamedContext wraps a lambda that, given
// a PredictRequest, runs the model via internal::RunPredict and forwards every
// streamed output map to `response_callback` as a PredictResponse.
137 absl::StatusOr<std::unique_ptr<PredictStreamedContext>>
138 TfrtSavedModelServable::PredictStreamed(
139 const RunOptions& run_options,
140 absl::AnyInvocable<
void(absl::StatusOr<PredictResponse>)>
// NOTE(review): this recorder is a local of the outer function and is not in
// the lambda's capture list below, so it does not appear to span the later
// execution of the request -- confirm this is intended.
142 auto recorder = CreateRecorder();
143 return std::make_unique<SingleRequestPredictStreamedContext>(
// `run_options` is captured by copy and `response_callback` is moved into the
// closure; `this` is captured as a pointer.
// NOTE(review): assumes the servable outlives the returned context -- confirm.
144 [
this, run_options, response_callback = std::move(response_callback)](
145 const PredictRequest& request)
mutable -> absl::Status {
146 TRACELITERAL(
"TfrtSavedModelServable::PredictStreamed");
148 auto tfrt_run_options = GetTFRTSavedModelRunOptions(run_options);
// An empty signature name in the request falls back to
// kDefaultServingSignatureDefKey.
150 std::string signature_name =
151 request.model_spec().signature_name().empty()
152 ? kDefaultServingSignatureDefKey
153 : request.model_spec().signature_name();
// Copy of the request's model spec with the resolved signature name and this
// servable's version filled in; attached to every streamed response below.
155 tensorflow::serving::ModelSpec model_spec = request.model_spec();
156 model_spec.set_signature_name(signature_name);
157 model_spec.mutable_version()->set_value(version());
// The callback captures by reference (`model_spec`, `response_callback`).
// NOTE(review): this is only safe if the callback is never invoked after the
// enclosing lambda returns -- confirm against RunPredict's contract.
159 tfrt_run_options.streamed_output_callback =
160 [&](absl::flat_hash_map<std::string, tensorflow::Tensor> outputs) {
161 tensorflow::serving::PredictResponse response;
162 *response.mutable_model_spec() = model_spec;
// Serialize each output tensor into the response's outputs map under its key.
164 for (
const auto& [output_key, output_tensor] : outputs) {
165 tensorflow::TensorProto& tensor_proto =
166 (*response.mutable_outputs())[output_key];
// Streamed outputs always use AsProtoField here, independent of
// predict_response_tensor_serialization_option_.
172 output_tensor.AsProtoField(&tensor_proto);
175 response_callback(std::move(response));
// Local response object for the RunPredict call; the streamed results are
// delivered through `response_callback` above.
181 PredictResponse response;
183 return internal::RunPredict(
184 tfrt_run_options, version(),
185 predict_response_tensor_serialization_option_, saved_model_.get(),
// Same thread-pool selection as in Predict(): defaults when no factory is set.
187 thread_pool_factory_ ==
nullptr
188 ? tsl::thread::ThreadPoolOptions()
189 : thread_pool_factory_->GetThreadPools().get());
// Handles a multi-inference request by delegating to RunMultiInference.
//
// `run_options` is converted with GetTFRTSavedModelRunOptions(); the result of
// RunMultiInference is returned unchanged and `response` is passed through.
193 absl::Status TfrtSavedModelServable::MultiInference(
194 const RunOptions& run_options,
const MultiInferenceRequest& request,
195 MultiInferenceResponse* response) {
196 TRACELITERAL(
"TfrtSavedModelServable::MultiInference");
// The recorder is a local that lives until this function returns.
197 auto recorder = CreateRecorder();
198 return RunMultiInference(GetTFRTSavedModelRunOptions(run_options), version(),
199 saved_model_.get(), request, response);
// Suspends the servable.
//
// Possible outcomes: an Unimplemented error, an immediate OK status, or the
// status produced by invoking suspend_fn_ on this servable.
202 absl::Status TfrtSavedModelServable::Suspend() {
203 TRACELITERAL(
"TfrtSavedModelServable::Suspend");
// paging_mu_ is held for the rest of the function (scoped lock).
204 absl::MutexLock lock(&paging_mu_);
// NOTE(review): presumably returned when no suspend_fn_ is installed --
// confirm against the guarding condition.
206 return absl::UnimplementedError(
"Suspend is not implemented");
// NOTE(review): presumably the already-suspended case, treated as a no-op --
// confirm against the guarding condition.
209 return absl::OkStatus();
// Delegate the actual suspension to the installed callback.
211 absl::Status status = suspend_fn_(
this);
// Resumes the servable.
//
// Possible outcomes: an Unimplemented error, an immediate OK status, or the
// status produced by invoking resume_fn_ on this servable.
218 absl::Status TfrtSavedModelServable::Resume() {
219 TRACELITERAL(
"TfrtSavedModelServable::Resume");
// paging_mu_ is held for the rest of the function (scoped lock).
220 absl::MutexLock lock(&paging_mu_);
// NOTE(review): presumably returned when no resume_fn_ is installed --
// confirm against the guarding condition.
222 return absl::UnimplementedError(
"Resume is not implemented");
// NOTE(review): presumably the not-suspended case, treated as a no-op --
// confirm against the guarding condition.
225 return absl::OkStatus();
// Delegate the actual resumption to the installed callback.
227 absl::Status status = resume_fn_(
this);
// Checks that every metadata field named in `request` is supported.
//
// kSignatureDef is the only accepted field. The first other value produces an
// InvalidArgument error that names the offending field; otherwise OK is
// returned.
236 absl::Status ValidateGetModelMetadataRequest(
237 const GetModelMetadataRequest& request) {
238 for (
const auto& metadata_field : request.metadata_field()) {
239 if (metadata_field != kSignatureDef) {
240 return absl::InvalidArgumentError(
241 absl::StrCat(
"Metadata field ", metadata_field,
" is not supported"));
// Reached when no unsupported field was found (including an empty list).
244 return absl::OkStatus();
// Out-of-line definition of RequestRecorder's destructor, using the
// compiler-generated (defaulted) implementation.
249 RequestRecorder::~RequestRecorder() =
default;
// Fills `response` with the model metadata requested in `request`.
//
// The request is validated first; any validation error is returned as-is.
// For the kSignatureDef field, the saved model's signature defs are collected
// into a SignatureDefMap and packed into the response metadata, and the
// response model spec is set to this servable's name and version.
251 absl::Status TfrtSavedModelServable::GetModelMetadata(
252 const GetModelMetadataRequest& request,
253 GetModelMetadataResponse* response) {
254 TRACELITERAL(
"TfrtSavedModelServable::GetModelMetadata");
// Returns early with the validation error, if any.
256 TF_RETURN_IF_ERROR(ValidateGetModelMetadataRequest(request));
258 for (
const auto& metadata_field : request.metadata_field()) {
259 if (metadata_field == kSignatureDef) {
260 SignatureDefMap signature_def_map;
// Copy each signature from the saved model's MetaGraphDef, keyed by name.
261 for (
const auto& signature :
262 saved_model_->GetMetaGraphDef().signature_def()) {
263 (*signature_def_map.mutable_signature_def())[signature.first] =
// Identify the responding model by this servable's name and version.
267 auto* response_model_spec = response->mutable_model_spec();
268 response_model_spec->set_name(std::string(name()));
269 response_model_spec->mutable_version()->set_value(version());
// The metadata entry is stored under the kSignatureDef key.
270 (*response->mutable_metadata())[kSignatureDef].PackFrom(
// NOTE(review): the validation above already rejects every field other than
// kSignatureDef, so this error path looks purely defensive -- confirm.
273 return absl::InvalidArgumentError(
274 absl::StrCat(
"MetadataField ", metadata_field,
" is not supported"));
278 return absl::OkStatus();
281 absl::StatusOr<std::unique_ptr<TfrtSavedModelServable>>
282 CreateTfrtSavedModelServable(
283 const tensorflow::tfrt_stub::SavedModel::Options& options,
284 absl::string_view name, int64_t version, absl::string_view saved_model_dir,
285 absl::flat_hash_set<std::string> tags) {
288 tensorflow::tfrt_stub::SavedModelImpl::LoadSavedModel(
289 options, saved_model_dir,
290 std::unordered_set<std::string>(tags.begin(), tags.end())));
293 auto saved_model_config,
294 LoadSavedModelConfigOrDefault(std::string(saved_model_dir)));
296 TfrtSavedModelConfig config;
297 return std::make_unique<TfrtSavedModelServable>(
298 name, version, config, saved_model_config, std::move(saved_model),