#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_factory.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "google/protobuf/wrappers.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tensorflow_serving/batching/tfrt_saved_model_with_batching.h"
#include "tensorflow_serving/core/loader.h"
#include "tensorflow_serving/resources/resource_values.h"
#include "tensorflow_serving/resources/resources.pb.h"
#include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
#include "tensorflow_serving/servables/tensorflow/machine_learning_metadata.h"
#include "tensorflow_serving/servables/tensorflow/saved_model_config.h"
#include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_warmup.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
#include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"
#include "tensorflow_serving/servables/tensorflow/thread_pool_factory_config.pb.h"
#include "tensorflow_serving/session_bundle/graph_rewriter.h"
namespace tensorflow {
namespace serving {
namespace {

using Batcher = SharedBatchScheduler<SavedModelBatchingTask>;
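
// Wraps `saved_model` so that requests to each function in `function_names`
// flow through a batching queue created on `batch_scheduler` before
// execution. On success, *saved_model is replaced by the batching wrapper.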
absl::Status WrapSavedModelForBatching(
    const BatchingParameters& batching_config,
    std::shared_ptr<Batcher> batch_scheduler,
    const std::vector<std::string>& function_names,
    std::unique_ptr<tfrt_stub::SavedModel>* saved_model) {
  LOG(INFO) << "Wrapping saved model to perform batch processing";

  if (batch_scheduler == nullptr) {
    return errors::Internal("batch_scheduler not set");
  }
  if (*saved_model == nullptr) {
    return errors::Internal("saved model not set");
  }

  auto queue_options =
      GetQueueOptions<tensorflow::serving::SavedModelBatchingTask>(
          batching_config,
          [](std::unique_ptr<tensorflow::serving::SavedModelBatchingTask>*
                 input_task,
             int open_batch_remaining_slot, int max_batch_size,
             std::vector<std::unique_ptr<
                 tensorflow::serving::SavedModelBatchingTask>>* output_tasks)
              -> absl::Status {
            return SplitSavedModelInputTask(input_task,
                                            open_batch_remaining_slot,
                                            max_batch_size, output_tasks);
          });

  SavedModelBatchingOptions batching_saved_model_options;
  for (int allowed_batch_size : batching_config.allowed_batch_sizes()) {
    batching_saved_model_options.allowed_batch_sizes.push_back(
        allowed_batch_size);
  }
  batching_saved_model_options.pad_variable_length_inputs =
      batching_config.pad_variable_length_inputs();

  auto create_queue =
      [batch_scheduler, queue_options](
          std::function<void(std::unique_ptr<Batch<SavedModelBatchingTask>>)>
              process_batch_callback,
          std::unique_ptr<BatchScheduler<SavedModelBatchingTask>>* queue) {
        TF_RETURN_IF_ERROR(batch_scheduler->AddQueue(
            queue_options, process_batch_callback, queue));
        return absl::OkStatus();
      };
  std::vector<FuncNameWithBatchingSchedulerCreator>
      func_name_with_batching_scheduler_creator;
  func_name_with_batching_scheduler_creator.reserve(function_names.size());
  for (const std::string& function_name : function_names) {
    func_name_with_batching_scheduler_creator.push_back(
        {function_name, create_queue});
  }

  return CreateSavedModelWithBatching(batching_saved_model_options,
                                      func_name_with_batching_scheduler_creator,
                                      std::move(*saved_model), saved_model);
}
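
// Converts the TpuUnpaddedBatchMode proto enum from the servable config to
// the corresponding TfrtCompileOptions value.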
TfrtCompileOptions::TpuAllowUnpaddedBatch ToTpuAllowUnpaddedBatch(
    const TfrtSavedModelConfig::TpuUnpaddedBatchMode
        tpu_unpadded_batch_mode_enum) {
  switch (tpu_unpadded_batch_mode_enum) {
    case TfrtSavedModelConfig::UNPADDED_BATCH_AUTO:
      return TfrtCompileOptions::TpuAllowUnpaddedBatch::kAuto;
    case TfrtSavedModelConfig::UNPADDED_BATCH_ENFORCED:
      return TfrtCompileOptions::TpuAllowUnpaddedBatch::kEnforced;
    case TfrtSavedModelConfig::UNPADDED_BATCH_DISABLED:
    default:
      return TfrtCompileOptions::TpuAllowUnpaddedBatch::kDisabled;
  }
}

}  // namespace
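
// Builds the stock factory: a shared batch scheduler and a thread pool
// factory derived from `config`, wired into a TfrtSavedModelFactory.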
absl::StatusOr<std::unique_ptr<TfrtSavedModelFactory>>
CreateDefaultTfrtSavedModelFactory(const TfrtSavedModelConfig& config) {
  TF_ASSIGN_OR_RETURN(auto batcher, CreateBatchSchedulerFromConfig(config));
  TF_ASSIGN_OR_RETURN(auto thread_pool_factory,
                      CreateThreadPoolFactoryFromConfig(config));
  return std::make_unique<TfrtSavedModelFactory>(
      config, batcher, std::move(thread_pool_factory));
}

TfrtSavedModelFactory::~TfrtSavedModelFactory() = default;
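
// Creates a factory via the create_fn held by the global registry; the
// registry is seeded with CreateDefaultTfrtSavedModelFactory.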
absl::Status TfrtSavedModelFactory::Create(
    const TfrtSavedModelConfig& config,
    std::unique_ptr<TfrtSavedModelFactory>* factory) {
  auto create_fn = GetGlobalTfrtSavedModelFactoryRegistry().Get();
  if (!create_fn) {
    return absl::InternalError(
        "Missing create_fn for the TfrtSavedModelFactory.");
  }
  TF_ASSIGN_OR_RETURN(*factory, create_fn(config));
  return absl::OkStatus();
}
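
// Estimates the resources a servable at `path` will consume, optionally
// basing the estimate on the recorded validation result.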
absl::Status TfrtSavedModelFactory::EstimateResourceRequirement(
    const std::string& path, ResourceAllocation* estimate) const {
  return EstimateResourceFromPath(
      path, config().resource_estimation_uses_validation_result(), estimate);
}
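
// Loads a TFRT SavedModel from `path`, applying the compile and execution
// options from the factory config, and wraps it for batching when batching
// parameters are configured.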
absl::Status TfrtSavedModelFactory::CreateTfrtSavedModelWithMetadata(
    const Loader::Metadata& metadata, const std::string& path,
    std::unique_ptr<tfrt::SavedModel>* saved_model) {
  std::unordered_set<std::string> saved_model_tags(
      config().saved_model_tags().begin(), config().saved_model_tags().end());
  // Defaults to the `serve` tag when no tags are specified in the config.
  if (saved_model_tags.empty()) {
    saved_model_tags.insert(kSavedModelTagServe);
  }

  LOG(INFO) << "Creating TFRT SavedModel for path: " << path
            << " with config: " << config_.DebugString();
  auto* runtime = tensorflow::tfrt_stub::GetGlobalRuntime();
  tfrt::SavedModel::Options options(runtime);
  options.graph_execution_options.use_ifrt = config_.tfrt_use_ifrt();
  TF_RETURN_IF_ERROR(RegisterCustomBackend(options.graph_execution_options));

  tensorflow::MetaGraphDef meta_graph_def;
  TF_RETURN_IF_ERROR(tensorflow::ReadMetaGraphDefFromSavedModel(
      std::string(path), saved_model_tags, &meta_graph_def));
  if (auto& graph_rewriter = tensorflow::serving::GraphRewriter::GetGlobal();
      graph_rewriter.IsRegistered()) {
    TF_RETURN_IF_ERROR(graph_rewriter.Get()(&meta_graph_def));
  }

  options.enable_lazy_loading =
      meta_graph_def.signature_def_size() > config_.lazy_init_threshold();
  options.maybe_load_from_mla = config_.maybe_load_from_mla();
  options.lazy_loading_use_graph_executor =
      config_.lazy_loading_use_graph_executor();
  auto& compile_options = options.graph_execution_options.compile_options;
  compile_options.enable_grappler = config_.enable_grappler();
  compile_options.graph_options = config_.graph_options();
  if (config_.enable_saved_model_config()) {
    TF_RETURN_IF_ERROR(LoadSavedModelConfig(
        path, options.graph_execution_options.compile_options.graph_options,
        options.graph_execution_options.runtime_config));
  }
  if (config_.target_tpu()) {
    compile_options.device_target = TfrtDeviceInfraTarget::kTpurt;
  } else if (config_.enable_tfrt_gpu()) {
    compile_options.device_target = TfrtDeviceInfraTarget::kGpu;
  } else {
    compile_options.device_target = TfrtDeviceInfraTarget::kCpu;
  }
  compile_options.hoist_invariant_ops = config_.hoist_invariant_ops();
  compile_options.sink_in_invariant_ops = config_.sink_in_invariant_ops();
  compile_options.cost_threshold = config_.stream_merge_threshold();
  compile_options.merge_inter_dependent_streams =
      config_.merge_inter_dependent_streams();
  compile_options.tpu_move_resource_gather_to_host =
      config_.tpu_move_resource_gather_to_host();
  compile_options.tpu_gather_table_width_threshold_bytes =
      config_.tpu_gather_table_width_threshold_bytes();
  compile_options.tpu_fuse_ops = config_.use_fused_tpu_op();
  compile_options.enable_while_parallel_iterations =
      config_.enable_while_parallel_iterations();
  compile_options.use_tpu_host_allocator_for_inputs =
      config_.use_tpu_host_allocator_for_inputs();
  compile_options.tpu_allow_unpadded_batch =
      ToTpuAllowUnpaddedBatch(config_.tpu_unpadded_batch_mode());
  compile_options.use_gpu_compile_and_execute_op =
      config_.tfrt_use_fused_gpu_op();
  compile_options.min_num_batch_threads = config_.tfrt_min_num_batch_threads();
  compile_options.min_max_enqueued_batches =
      config_.tfrt_min_max_enqueued_batches();
  compile_options.batch_padding_policy = config_.batch_padding_policy();

  options.graph_execution_options.run_placer_grappler_on_functions =
      config_.run_placer_grappler_on_functions();
  options.graph_execution_options.enable_tfrt_gpu = config_.enable_tfrt_gpu();
  options.graph_execution_options.tfrt_gpu_parallelism =
      config_.tfrt_gpu_parallelism();
  options.graph_execution_options.gpu_system_memory_size_in_mb =
      config_.gpu_system_memory_size_in_mb();
  options.graph_execution_options.enable_grappler_function_optimizer =
      config_.enable_grappler_function_optimizer();
  options.graph_execution_options.enable_online_cost_analysis =
      config_.enable_online_cost_analysis();
  options.graph_execution_options.enable_mlrt = config_.enable_mlrt();
  options.graph_execution_options.model_metadata.set_name(
      metadata.servable_id.name);
  options.graph_execution_options.model_metadata.set_version(
      metadata.servable_id.version);

  TF_ASSIGN_OR_RETURN(*saved_model,
                      tfrt::SavedModelImpl::LoadSavedModel(
                          std::move(options), std::move(meta_graph_def), path));

  if (config_.has_batching_parameters() &&
      config_.batching_parameters().ByteSizeLong() != 0) {
    absl::optional<BatchingParameters> batching_params;
    TF_RETURN_IF_ERROR(GetPerModelBatchingParams(
        path, config_.batching_parameters(),
        config_.enable_per_model_batching_params(), &batching_params));
    if (batching_params.has_value()) {
      LOG(INFO) << "Wrapping TFRT SavedModel for batching with params: "
                << batching_params.value().DebugString();
      return WrapSavedModelForBatching(
          batching_params.value(), batch_scheduler_,
          (*saved_model)->GetFunctionNames(), saved_model);
    }
  }
  return absl::OkStatus();
}
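
// Creates the Servable for a model at `path`: honors a custom override if one
// is provided, otherwise loads the TFRT SavedModel, publishes MLMD metadata,
// and optionally runs warmup and freezes the model.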
absl::Status TfrtSavedModelFactory::CreateTfrtSavedModelWithMetadata(
    const Loader::Metadata& metadata, const std::string& path,
    std::unique_ptr<Servable>* servable) {
  TF_ASSIGN_OR_RETURN(auto override_servable, OverrideServable(metadata, path));
  if (override_servable) {
    *servable = std::move(override_servable);
    return absl::OkStatus();
  }

  std::unique_ptr<tfrt_stub::SavedModel> saved_model;
  TF_RETURN_IF_ERROR(
      CreateTfrtSavedModelWithMetadata(metadata, path, &saved_model));

  MaybePublishMLMDStreamz(path, metadata.servable_id.name,
                          metadata.servable_id.version);
  TF_ASSIGN_OR_RETURN(auto saved_model_config,
                      LoadSavedModelConfigOrDefault(path));
  *servable = std::make_unique<TfrtSavedModelServable>(
      metadata.servable_id.name, metadata.servable_id.version, config_,
      saved_model_config, std::move(saved_model), thread_pool_factory_.get(),
      recorder_creator_);
  TfrtSavedModelServable* tfrt_servable =
      down_cast<TfrtSavedModelServable*>(servable->get());

  if (config().enable_model_warmup()) {
    auto* warmup_options = mutable_config().mutable_model_warmup_options();
    warmup_options->set_model_name(metadata.servable_id.name);
    warmup_options->set_model_version(metadata.servable_id.version);
    TF_RETURN_IF_ERROR(RunSavedModelWarmup(
        *warmup_options, path, config().lazy_init_threshold(),
        config().skip_warmup_requests_if_initialized(),
        &tfrt_servable->saved_model()));
    if (config().freeze_after_init()) {
      TF_RETURN_IF_ERROR(Freeze(tfrt_servable->saved_model()));
    }
  }

  return absl::OkStatus();
}

TfrtSavedModelFactory::TfrtSavedModelFactory(
    const TfrtSavedModelConfig& config,
    std::shared_ptr<Batcher> batch_scheduler,
    std::unique_ptr<ThreadPoolFactory> thread_pool_factory,
    std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
        recorder_creator)
    : config_(config),
      batch_scheduler_(batch_scheduler),
      thread_pool_factory_(std::move(thread_pool_factory)),
      recorder_creator_(std::move(recorder_creator)) {}
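
// Seeds the registry with the default factory creator.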
TfrtSavedModelFactoryRegistry::TfrtSavedModelFactoryRegistry() {
  factory_create_fn_ = [](const TfrtSavedModelConfig& config) {
    return CreateDefaultTfrtSavedModelFactory(config);
  };
}
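
// Reports which device resource pool this factory's servables consume, based
// on the configured SavedModel tags.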
absl::string_view TfrtSavedModelFactory::GetServingResourceType() const {
  if (std::any_of(config_.saved_model_tags().begin(),
                  config_.saved_model_tags().end(),
                  [](const auto& tag) { return tag == kSavedModelTagTpu; })) {
    return device_types::kTpu;
  }
  if (std::any_of(config_.saved_model_tags().begin(),
                  config_.saved_model_tags().end(),
                  [](const auto& tag) { return tag == kSavedModelTagGpu; })) {
    return device_types::kGpu;
  }
  return device_types::kMain;
}

TfrtSavedModelFactoryRegistry& GetGlobalTfrtSavedModelFactoryRegistry() {
  static auto* const registry = new TfrtSavedModelFactoryRegistry;
  return *registry;
}
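
// A minimal registration sketch, assuming the registry exposes a setter for
// factory_create_fn_ (e.g. a Register() method declared in the header, not
// shown in this file); CreateMyCustomFactory is a hypothetical helper:
//
//   GetGlobalTfrtSavedModelFactoryRegistry().Register(
//       [](const TfrtSavedModelConfig& config)
//           -> absl::StatusOr<std::unique_ptr<TfrtSavedModelFactory>> {
//         return CreateMyCustomFactory(config);
//       });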
absl::StatusOr<std::shared_ptr<TfrtSavedModelFactory::Batcher>>
CreateBatchSchedulerFromConfig(const TfrtSavedModelConfig& config) {
  std::shared_ptr<Batcher> batcher;
  if (config.has_batching_parameters() &&
      config.batching_parameters().ByteSizeLong() != 0) {
    TF_RETURN_IF_ERROR(
        CreateBatchScheduler(config.batching_parameters(), &batcher));
  }
  return batcher;
}
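
// Instantiates a ThreadPoolFactory from the text-proto config file referenced
// by the servable config; returns a null factory when no filepath is set.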
absl::StatusOr<std::unique_ptr<ThreadPoolFactory>>
CreateThreadPoolFactoryFromConfig(const TfrtSavedModelConfig& config) {
  const auto& thread_pool_factory_config_filepath =
      config.thread_pool_factory_config_filepath();
  std::unique_ptr<ThreadPoolFactory> thread_pool_factory;
  if (!thread_pool_factory_config_filepath.empty()) {
    ThreadPoolFactoryConfig thread_pool_factory_config;
    TF_RETURN_IF_ERROR(tsl::ReadTextProto(tsl::Env::Default(),
                                          thread_pool_factory_config_filepath,
                                          &thread_pool_factory_config));
    TF_RETURN_IF_ERROR(ThreadPoolFactoryRegistry::CreateFromAny(
        thread_pool_factory_config.thread_pool_factory_config(),
        &thread_pool_factory));
  }
  return thread_pool_factory;
}

}  // namespace serving
}  // namespace tensorflow