TensorFlow Serving C++ API Documentation
tfrt_saved_model_factory.cc
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_factory.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "google/protobuf/wrappers.pb.h"
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "absl/strings/string_view.h"
29 #include "absl/types/optional.h"
30 #include "tensorflow/cc/saved_model/reader.h"
31 #include "tensorflow/cc/saved_model/tag_constants.h"
32 #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
33 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/protobuf/config.pb.h"
36 #include "tensorflow/core/protobuf/meta_graph.pb.h"
37 #include "tensorflow/core/public/session_options.h"
38 #include "tensorflow/core/tfrt/runtime/runtime.h"
39 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
40 #include "tsl/platform/casts.h"
41 #include "tsl/platform/env.h"
42 #include "tsl/platform/errors.h"
43 #include "tensorflow_serving/batching/tfrt_saved_model_with_batching.h"
44 #include "tensorflow_serving/core/loader.h"
45 #include "tensorflow_serving/resources/resource_values.h"
46 #include "tensorflow_serving/resources/resources.pb.h"
47 #include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
48 #include "tensorflow_serving/servables/tensorflow/machine_learning_metadata.h"
49 #include "tensorflow_serving/servables/tensorflow/saved_model_config.h"
50 #include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
51 #include "tensorflow_serving/servables/tensorflow/servable.h"
52 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
53 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_warmup.h"
54 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
55 #include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"
56 #include "tensorflow_serving/servables/tensorflow/thread_pool_factory_config.pb.h"
57 #include "tensorflow_serving/session_bundle/graph_rewriter.h"
58 
59 namespace tensorflow {
60 namespace serving {
61 namespace {
62 
63 using Batcher = SharedBatchScheduler<SavedModelBatchingTask>;
64 
65 absl::Status WrapSavedModelForBatching(
66  const BatchingParameters& batching_config,
67  std::shared_ptr<Batcher> batch_scheduler,
68  const std::vector<std::string>& function_names,
69  std::unique_ptr<tfrt_stub::SavedModel>* saved_model) {
70  LOG(INFO) << "Wrapping saved model to perform batch processing";
71 
72  if (batch_scheduler == nullptr) {
73  return errors::Internal("batch_scheduler not set");
74  }
75  if (*saved_model == nullptr) {
76  return errors::Internal("saved model not set");
77  }
78 
79  auto queue_options =
80  GetQueueOptions<tensorflow::serving::SavedModelBatchingTask>(
81  batching_config,
82  [](std::unique_ptr<tensorflow::serving::SavedModelBatchingTask>*
83  input_task,
84  int open_batch_remaining_slot, int max_batch_size,
85  std::vector<
86  std::unique_ptr<tensorflow::serving::SavedModelBatchingTask>>*
87  output_tasks) -> absl::Status {
88  return SplitSavedModelInputTask(input_task,
89  open_batch_remaining_slot,
90  max_batch_size, output_tasks);
91  });
92 
93  SavedModelBatchingOptions batching_saved_model_options;
94  for (int allowed_batch_size : batching_config.allowed_batch_sizes()) {
95  batching_saved_model_options.allowed_batch_sizes.push_back(
96  allowed_batch_size);
97  }
98 
99  batching_saved_model_options.pad_variable_length_inputs =
100  batching_config.pad_variable_length_inputs();
101 
102  auto create_queue =
103  [batch_scheduler, queue_options](
104  std::function<void(std::unique_ptr<Batch<SavedModelBatchingTask>>)>
105  process_batch_callback,
106  std::unique_ptr<BatchScheduler<SavedModelBatchingTask>>* queue) {
107  TF_RETURN_IF_ERROR(batch_scheduler->AddQueue(
108  queue_options, process_batch_callback, queue));
109  return absl::OkStatus();
110  };
111  std::vector<FuncNameWithBatchingSchedulerCreator>
112  func_name_with_batching_scheduler_creator;
113  func_name_with_batching_scheduler_creator.reserve(function_names.size());
114  for (const std::string& function_name : function_names) {
115  func_name_with_batching_scheduler_creator.push_back(
116  {function_name, create_queue});
117  }
118 
119  return CreateSavedModelWithBatching(batching_saved_model_options,
120  func_name_with_batching_scheduler_creator,
121  std::move(*saved_model), saved_model);
122 }
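// Illustrative call sequence for the wrapper above (hypothetical caller code;
// assumes a scheduler and a loaded model already exist):
//
//   std::vector<std::string> function_names = saved_model->GetFunctionNames();
//   TF_RETURN_IF_ERROR(WrapSavedModelForBatching(
//       batching_config, batch_scheduler, function_names, &saved_model));
//
// On success, `saved_model` is replaced by a batching wrapper that owns the
// original model and registers one batching queue creator per function name.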
123 
124 TfrtCompileOptions::TpuAllowUnpaddedBatch ToTpuAllowUnpaddedBatch(
125  const TfrtSavedModelConfig::TpuUnpaddedBatchMode
126  tpu_unpadded_batch_mode_enum) {
127  switch (tpu_unpadded_batch_mode_enum) {
128  case TfrtSavedModelConfig::UNPADDED_BATCH_AUTO:
129  return TfrtCompileOptions::TpuAllowUnpaddedBatch::kAuto;
130  case TfrtSavedModelConfig::UNPADDED_BATCH_ENFORCED:
131  return TfrtCompileOptions::TpuAllowUnpaddedBatch::kEnforced;
132  case TfrtSavedModelConfig::UNPADDED_BATCH_DISABLED:
133  default:
134  return TfrtCompileOptions::TpuAllowUnpaddedBatch::kDisabled;
135  }
136 }
137 
138 absl::StatusOr<std::unique_ptr<TfrtSavedModelFactory>>
139 CreateDefaultTfrtSavedModelFactory(const TfrtSavedModelConfig& config) {
140  TF_ASSIGN_OR_RETURN(auto batcher, CreateBatchSchedulerFromConfig(config));
141  TF_ASSIGN_OR_RETURN(auto thread_pool_factory,
142  CreateThreadPoolFactoryFromConfig(config));
143 
144  return std::make_unique<TfrtSavedModelFactory>(
145  config, batcher, std::move(thread_pool_factory));
146 }
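// The default factory wires together the two optional components derived from
// the config: a shared batch scheduler (only if batching parameters are set)
// and a thread pool factory (only if a thread_pool_factory_config_filepath is
// set). Either may be null; see CreateBatchSchedulerFromConfig and
// CreateThreadPoolFactoryFromConfig at the bottom of this file.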
147 
148 } // namespace
149 
150 TfrtSavedModelFactory::~TfrtSavedModelFactory() = default;
151 
152 absl::Status TfrtSavedModelFactory::Create(
153  const TfrtSavedModelConfig& config,
154  std::unique_ptr<TfrtSavedModelFactory>* factory) {
155  auto create_fn = GetGlobalTfrtSavedModelFactoryRegistry().Get();
156  if (!create_fn) {
157  return absl::InternalError(
158  "Missing create_fn for the TfrtSavedModelFactory.");
159  }
160  TF_ASSIGN_OR_RETURN(*factory, create_fn(config));
161  return absl::OkStatus();
162 }
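// Minimal usage sketch (hypothetical caller with an already populated config):
//
//   TfrtSavedModelConfig config;
//   std::unique_ptr<TfrtSavedModelFactory> factory;
//   TF_RETURN_IF_ERROR(TfrtSavedModelFactory::Create(config, &factory));
//
// Create() consults the global registry, so a custom create function installed
// in GetGlobalTfrtSavedModelFactoryRegistry() is picked up transparently.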
163 
164 absl::Status TfrtSavedModelFactory::EstimateResourceRequirement(
165  const std::string& path, ResourceAllocation* estimate) const {
166  return EstimateResourceFromPath(
167  path, config().resource_estimation_uses_validation_result(), estimate);
168 }
169 
170 absl::Status TfrtSavedModelFactory::CreateTfrtSavedModelWithMetadata(
171  const Loader::Metadata& metadata, const std::string& path,
172  std::unique_ptr<tfrt::SavedModel>* saved_model) {
173  std::unordered_set<std::string> saved_model_tags(
174  config().saved_model_tags().begin(), config().saved_model_tags().end());
175  // Defaults to loading the meta graph def corresponding to the `serve` tag if
176  // no `saved_model_tags` are specified.
177  if (saved_model_tags.empty()) {
178  saved_model_tags.insert(kSavedModelTagServe);
179  }
180 
181  LOG(INFO) << "Creating TFRT SavedModel for path: " << path
182  << " with config: " << config_.DebugString();
183  auto* runtime = tensorflow::tfrt_stub::GetGlobalRuntime();
184  tfrt::SavedModel::Options options(runtime);
185 
186  // Registering the right type of custom backend currently only requires
187  // setting `use_ifrt`.
188  options.graph_execution_options.use_ifrt = config_.tfrt_use_ifrt();
189  TF_RETURN_IF_ERROR(RegisterCustomBackend(options.graph_execution_options));
190 
191  // TODO(b/326069213): Consider using arena allocation when loading a
192  // MetaGraphDef.
193  tensorflow::MetaGraphDef meta_graph_def;
194  TF_RETURN_IF_ERROR(tensorflow::ReadMetaGraphDefFromSavedModel(
195  std::string(path), saved_model_tags, &meta_graph_def));
196  if (auto& graph_rewriter = tensorflow::serving::GraphRewriter::GetGlobal();
197  graph_rewriter.IsRegistered()) {
198  TF_RETURN_IF_ERROR(graph_rewriter.Get()(&meta_graph_def));
199  }
200  options.enable_lazy_loading =
201  meta_graph_def.signature_def_size() > config_.lazy_init_threshold();
202  options.maybe_load_from_mla = config_.maybe_load_from_mla();
203  options.lazy_loading_use_graph_executor =
204  config_.lazy_loading_use_graph_executor();
205  auto& compile_options = options.graph_execution_options.compile_options;
206  compile_options.enable_grappler = config_.enable_grappler();
207  compile_options.graph_options = config_.graph_options();
208  if (config_.enable_saved_model_config()) {
209  TF_RETURN_IF_ERROR(LoadSavedModelConfig(
210  path, options.graph_execution_options.compile_options.graph_options,
211  options.graph_execution_options.runtime_config));
212  }
213  if (config_.target_tpu()) {
214  compile_options.device_target = TfrtDeviceInfraTarget::kTpurt;
215  } else if (config_.enable_tfrt_gpu()) {
216  compile_options.device_target = TfrtDeviceInfraTarget::kGpu;
217  } else {
218  compile_options.device_target = TfrtDeviceInfraTarget::kCpu;
219  }
220  compile_options.hoist_invariant_ops = config_.hoist_invariant_ops();
221  compile_options.sink_in_invariant_ops = config_.sink_in_invariant_ops();
222  compile_options.cost_threshold = config_.stream_merge_threshold();
223  compile_options.merge_inter_dependent_streams =
224  config_.merge_inter_dependent_streams();
225  compile_options.tpu_move_resource_gather_to_host =
226  config_.tpu_move_resource_gather_to_host();
227  compile_options.tpu_gather_table_width_threshold_bytes =
228  config_.tpu_gather_table_width_threshold_bytes();
229  compile_options.tpu_fuse_ops = config_.use_fused_tpu_op();
230  compile_options.enable_while_parallel_iterations =
231  config_.enable_while_parallel_iterations();
232  compile_options.use_tpu_host_allocator_for_inputs =
233  config_.use_tpu_host_allocator_for_inputs();
234  compile_options.tpu_allow_unpadded_batch =
235  ToTpuAllowUnpaddedBatch(config_.tpu_unpadded_batch_mode());
236  compile_options.use_gpu_compile_and_execute_op =
237  config_.tfrt_use_fused_gpu_op();
238  compile_options.min_num_batch_threads = config_.tfrt_min_num_batch_threads();
239  compile_options.min_max_enqueued_batches =
240  config_.tfrt_min_max_enqueued_batches();
241  compile_options.batch_padding_policy = config_.batch_padding_policy();
242 
243  options.graph_execution_options.run_placer_grappler_on_functions =
244  config_.run_placer_grappler_on_functions();
245  options.graph_execution_options.enable_tfrt_gpu = config_.enable_tfrt_gpu();
246  options.graph_execution_options.tfrt_gpu_parallelism =
247  config_.tfrt_gpu_parallelism();
248  options.graph_execution_options.gpu_system_memory_size_in_mb =
249  config_.gpu_system_memory_size_in_mb();
250  options.graph_execution_options.enable_grappler_function_optimizer =
251  config_.enable_grappler_function_optimizer();
252  options.graph_execution_options.enable_online_cost_analysis =
253  config_.enable_online_cost_analysis();
254  options.graph_execution_options.enable_mlrt = config_.enable_mlrt();
255  options.graph_execution_options.model_metadata.set_name(
256  metadata.servable_id.name);
257  options.graph_execution_options.model_metadata.set_version(
258  metadata.servable_id.version);
259 
260  TF_ASSIGN_OR_RETURN(*saved_model,
261  tfrt::SavedModelImpl::LoadSavedModel(
262  std::move(options), std::move(meta_graph_def), path));
263  if (config_.has_batching_parameters() &&
264  config_.batching_parameters().ByteSizeLong() != 0) {
265  absl::optional<BatchingParameters> batching_params;
266  TF_RETURN_IF_ERROR(GetPerModelBatchingParams(
267  path, config_.batching_parameters(),
268  config_.enable_per_model_batching_params(), &batching_params));
269  if (batching_params.has_value()) {
270  LOG(INFO) << "Wrapping TFRT SavedModel for batching with params: "
271  << batching_params.value().DebugString();
272  return WrapSavedModelForBatching(
273  batching_params.value(), batch_scheduler_,
274  (*saved_model)->GetFunctionNames(), saved_model);
275  }
276  }
277  return absl::OkStatus();
278 }
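// Load path summary for the overload above: read the MetaGraphDef for the
// requested tags, apply the optional global graph rewriter, translate the
// TfrtSavedModelConfig into TFRT compile and graph-execution options, load the
// model, and finally, if batching parameters are configured, wrap it with the
// per-signature batching layer defined at the top of this file.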
279 
280 absl::Status TfrtSavedModelFactory::CreateTfrtSavedModelWithMetadata(
281  const Loader::Metadata& metadata, const std::string& path,
282  std::unique_ptr<Servable>* servable) {
283  TF_ASSIGN_OR_RETURN(auto override_servable, OverrideServable(metadata, path));
284  if (override_servable) {
285  *servable = std::move(override_servable);
286  return absl::OkStatus();
287  }
288 
289  std::unique_ptr<tfrt_stub::SavedModel> saved_model;
290  TF_RETURN_IF_ERROR(
291  CreateTfrtSavedModelWithMetadata(metadata, path, &saved_model));
292 
293  MaybePublishMLMDStreamz(path, metadata.servable_id.name,
294  metadata.servable_id.version);
295  TF_ASSIGN_OR_RETURN(auto saved_model_config,
296  LoadSavedModelConfigOrDefault(path));
297 
298  *servable = std::make_unique<TfrtSavedModelServable>(
299  metadata.servable_id.name, metadata.servable_id.version, config_,
300  saved_model_config, std::move(saved_model), thread_pool_factory_.get(),
301  recorder_creator_);
302  TfrtSavedModelServable* tfrt_servable =
303  down_cast<TfrtSavedModelServable*>(servable->get());
304 
305  if (config().enable_model_warmup()) {
306  auto* warmup_options = mutable_config().mutable_model_warmup_options();
307  warmup_options->set_model_name(metadata.servable_id.name);
308  warmup_options->set_model_version(metadata.servable_id.version);
309  TF_RETURN_IF_ERROR(RunSavedModelWarmup(
310  *warmup_options, path, config().lazy_init_threshold(),
311  config().skip_warmup_requests_if_initialized(),
312  &tfrt_servable->saved_model()));
313  if (config().freeze_after_init()) {
314  TF_RETURN_IF_ERROR(Freeze(tfrt_servable->saved_model()));
315  }
316  }
317 
318  return absl::OkStatus();
319 }
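// Post-construction flow for the overload above: if model warmup is enabled,
// the warmup requests shipped with the SavedModel (typically under
// assets.extra) are replayed against the servable's model, and the model is
// frozen afterwards when freeze_after_init is set.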
320 
321 TfrtSavedModelFactory::TfrtSavedModelFactory(
322  const TfrtSavedModelConfig& config,
323  std::shared_ptr<Batcher> batch_scheduler,
324  std::unique_ptr<ThreadPoolFactory> thread_pool_factory,
325  std::function<std::unique_ptr<RequestRecorder>(TfrtSavedModelServable&)>
326  recorder_creator)
327  : config_(config),
328  batch_scheduler_(batch_scheduler),
329  thread_pool_factory_(std::move(thread_pool_factory)),
330  recorder_creator_(std::move(recorder_creator)) {}
331 
332 TfrtSavedModelFactoryRegistry::TfrtSavedModelFactoryRegistry() {
333  factory_create_fn_ = [](const TfrtSavedModelConfig& config) {
334  return CreateDefaultTfrtSavedModelFactory(config);
335  };
336 }
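// By default the registry hands out CreateDefaultTfrtSavedModelFactory, so
// TfrtSavedModelFactory::Create() works without any explicit registration;
// embedders can install their own create function to return a subclass.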
337 
338 absl::string_view TfrtSavedModelFactory::GetServingResourceType() const {
339  if (std::any_of(config_.saved_model_tags().begin(),
340  config_.saved_model_tags().end(),
341  [](const auto& tag) { return tag == kSavedModelTagTpu; })) {
342  return device_types::kTpu;
343  }
344  if (std::any_of(config_.saved_model_tags().begin(),
345  config_.saved_model_tags().end(),
346  [](const auto& tag) { return tag == kSavedModelTagGpu; })) {
347  return device_types::kGpu;
348  }
349  return device_types::kMain;
350 }
351 
352 TfrtSavedModelFactoryRegistry& GetGlobalTfrtSavedModelFactoryRegistry() {
353  static auto* const registry = new TfrtSavedModelFactoryRegistry;
354  return *registry;
355 }
356 
357 absl::StatusOr<std::shared_ptr<TfrtSavedModelFactory::Batcher>>
358 CreateBatchSchedulerFromConfig(const TfrtSavedModelConfig& config) {
359  std::shared_ptr<Batcher> batcher;
360  if (config.has_batching_parameters() &&
361  config.batching_parameters().ByteSizeLong() != 0) {
362  TF_RETURN_IF_ERROR(
363  CreateBatchScheduler(config.batching_parameters(), &batcher));
364  }
365  return batcher;
366 }
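// Note: when the config carries no non-empty batching parameters, the returned
// scheduler is null. That is a valid state; the factory then loads the model
// without the batching wrapper.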
367 
368 absl::StatusOr<std::unique_ptr<ThreadPoolFactory>>
369 CreateThreadPoolFactoryFromConfig(const TfrtSavedModelConfig& config) {
370  const auto& thread_pool_factory_config_filepath =
371  config.thread_pool_factory_config_filepath();
372  std::unique_ptr<ThreadPoolFactory> thread_pool_factory;
373  if (!thread_pool_factory_config_filepath.empty()) {
374  ThreadPoolFactoryConfig thread_pool_factory_config;
375  TF_RETURN_IF_ERROR(tsl::ReadTextProto(tsl::Env::Default(),
376  thread_pool_factory_config_filepath,
377  &thread_pool_factory_config));
378  TF_RETURN_IF_ERROR(ThreadPoolFactoryRegistry::CreateFromAny(
379  thread_pool_factory_config.thread_pool_factory_config(),
380  &thread_pool_factory));
381  }
382  return thread_pool_factory;
383 }
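// Example of pointing the factory at a thread pool configuration (hypothetical
// path; the file is expected to contain a ThreadPoolFactoryConfig text proto):
//
//   TfrtSavedModelConfig config;
//   config.set_thread_pool_factory_config_filepath(
//       "/path/to/thread_pool_factory_config.txtpb");
//   auto thread_pool_factory = CreateThreadPoolFactoryFromConfig(config);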
384 
385 } // namespace serving
386 } // namespace tensorflow