16 #include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
21 #include "absl/log/log.h"
22 #include "absl/status/statusor.h"
23 #include "absl/strings/escaping.h"
24 #include "tensorflow/cc/saved_model/constants.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.pb.h"
27 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
28 #include "tsl/platform/env.h"
29 #include "tsl/platform/errors.h"
30 #include "tsl/platform/file_system.h"
31 #include "tsl/platform/path.h"
32 #include "tsl/platform/stringpiece.h"
33 #include "tsl/platform/types.h"
34 #include "tensorflow_serving/servables/tensorflow/remote_op_config_rewriter.pb.h"
35 #include "tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
37 namespace tensorflow {
40 void AddOrReplaceOptimizer(
const std::string& custom_optimizer_name,
41 const std::string& parameter_key,
42 const std::string& parameter_value,
43 RewriterConfig* rewrite_options) {
44 google::protobuf::Map<std::string, AttrValue>* parameter_map =
nullptr;
45 for (
auto& custom_optimizer : *rewrite_options->mutable_custom_optimizers()) {
46 if (custom_optimizer.name() == custom_optimizer_name) {
47 parameter_map = custom_optimizer.mutable_parameter_map();
52 if (parameter_map ==
nullptr) {
53 auto* custom_optimizer = rewrite_options->add_custom_optimizers();
54 custom_optimizer->set_name(custom_optimizer_name);
55 parameter_map = custom_optimizer->mutable_parameter_map();
58 (*parameter_map)[parameter_key].set_s(absl::Base64Escape(parameter_value));
62 void UpdateRewriterConfig(
63 const tensorflow::serving::SessionOverrides& session_overrides,
64 tensorflow::RewriterConfig* rewrite_options) {
65 DCHECK(rewrite_options !=
nullptr);
68 if (session_overrides.has_remote_op_remap_config()) {
69 AddOrReplaceOptimizer(
70 kRemoteOpConfigRewriter, kRemoteOpRewriteConfigParamKey,
71 session_overrides.remote_op_remap_config().SerializeAsString(),
76 if (session_overrides.has_batch_op_rewriter_config()) {
77 AddOrReplaceOptimizer(
78 kBatchOpRewriter, kBatchOpRewriteConfigParamKey,
79 session_overrides.batch_op_rewriter_config().SerializeAsString(),
84 absl::StatusOr<SavedModelConfig> LoadSavedModelConfigOrDefault(
85 const std::string& export_dir) {
86 const std::string saved_model_config_path = tsl::io::JoinPath(
87 export_dir, kSavedModelAssetsExtraDirectory, kSavedModelConfigPath);
88 SavedModelConfig saved_model_config;
89 if (!tsl::Env::Default()->FilesExist({saved_model_config_path},
nullptr)) {
91 return saved_model_config;
94 LOG(INFO) <<
"Loading model config from " << saved_model_config_path;
96 tsl::uint64 file_size = 0;
98 tsl::Env::Default()->GetFileSize(saved_model_config_path, &file_size));
99 content.resize(file_size);
101 std::unique_ptr<tsl::RandomAccessFile> file;
103 tsl::Env::Default()->NewRandomAccessFile(saved_model_config_path, &file));
105 absl::string_view result;
106 TF_RETURN_IF_ERROR(file->Read(0, file_size, &result, &(content)[0]));
108 if (!saved_model_config.ParseFromString(content)) {
109 return tsl::errors::Internal(
"Unable to parse SavedModelConfig: ",
110 saved_model_config_path);
112 LOG(INFO) <<
"Finished loading model config from " << saved_model_config_path
113 <<
":" << saved_model_config.DebugString();
114 return saved_model_config;