16 #include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "absl/status/statusor.h"
23 #include "absl/strings/escaping.h"
24 #include "absl/strings/substitute.h"
25 #include "google/protobuf/text_format.h"
26 #include "xla/tsl/lib/core/status_test_util.h"
27 #include "tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.pb.h"
28 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
29 #include "tensorflow_serving/servables/tensorflow/remote_op_config_rewriter.pb.h"
30 #include "tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
31 #include "tensorflow_serving/test_util/test_util.h"
33 namespace tensorflow {
// Test-data SavedModel export directories. Each is passed to
// test_util::TestSrcDirPath() by the tests below to obtain the export
// directory handed to LoadSavedModelConfigOrDefault().

// Export used by the "missing config" test: carries no SavedModelConfig.
constexpr char kTestSavedModelWithoutSavedModelConfigPath[] =
    "servables/tensorflow/testdata/"
    "saved_model_half_plus_two_cpu/00000123";

// Export used by the loading and rewriter tests: carries a populated
// SavedModelConfig.
constexpr char kTestSavedModelWithSavedModelConfigPath[] =
    "servables/tensorflow/testdata/"
    "saved_model_half_plus_two_cpu_with_saved_model_config/00000123";

// Export used by the "empty config" test: carries an empty SavedModelConfig.
constexpr char kTestSavedModelWithEmptySavedModelConfigPath[] =
    "servables/tensorflow/testdata/"
    "saved_model_half_plus_two_cpu_with_empty_saved_model_config/00000123";
49 using test_util::EqualsProto;
51 TEST(LoadSavedModeConfigTest, MissingSavedModelConfig) {
52 const std::string export_dir =
53 test_util::TestSrcDirPath(kTestSavedModelWithoutSavedModelConfigPath);
55 absl::StatusOr<SavedModelConfig> saved_model_config =
56 LoadSavedModelConfigOrDefault(export_dir);
57 TF_ASSERT_OK(saved_model_config.status());
58 EXPECT_THAT(saved_model_config.value(), EqualsProto(
""));
61 TEST(LoadSavedModelConfigTest, EmptySavedModelConfig) {
62 const std::string export_dir =
63 test_util::TestSrcDirPath(kTestSavedModelWithEmptySavedModelConfigPath);
65 absl::StatusOr<SavedModelConfig> saved_model_config =
66 LoadSavedModelConfigOrDefault(export_dir);
68 TF_ASSERT_OK(saved_model_config.status());
69 EXPECT_THAT(saved_model_config.value(), EqualsProto(
""));
// Loads an export that ships a populated SavedModelConfig. Checks that the
// loaded config equals an expected config containing remote-op remap,
// batch-op rewriter and TFRT runtime settings.
72 TEST(LoadSavedModelConfigTest, SavedModelConfig) {
73 const std::string export_dir =
74 test_util::TestSrcDirPath(kTestSavedModelWithSavedModelConfigPath);
75 absl::StatusOr<SavedModelConfig> saved_model_config =
76 LoadSavedModelConfigOrDefault(export_dir);
78 TF_ASSERT_OK(saved_model_config.status());
// Build the expected config from a text-format proto.
// NOTE(review): this listing is missing lines (see the gaps in the embedded
// numbering). The raw-string opener/closer, several nested fields, the
// `&expected_config` argument and any check of `result` (e.g.
// EXPECT_TRUE(result)) are not visible. Confirm against the original source;
// without a check, a malformed text proto would silently produce a wrong
// expectation.
80 SavedModelConfig expected_config;
81 bool result = ::google::protobuf::TextFormat::ParseFromString(
84 remote_op_remap_config {
86 key: "placeholder_model_name"
89 target_address_remap {
90 key: "placeholder_model_name"
91 value: "target_address"
94 batch_op_rewriter_config {
96 key: "placeholder_model_name"
98 batch_timeout_micros: 100
99 allowed_batch_sizes: [ 2, 4, 8 ]
104 tfrt_runtime_config {
106 type_url: "type.googleapis.com/tensorflow.tfrt_stub.TestConfig1"
107 value: "\n\rtest config 1"
// The loaded config must match the expected proto (compared via
// test_util::EqualsProto).
115 EXPECT_THAT(saved_model_config.value(), EqualsProto(expected_config));
// Applying the session overrides of a loaded SavedModelConfig to an empty
// RewriterConfig must add exactly two custom optimizers:
// - the remote-op config rewriter, carrying the Base64-escaped serialized
//   remote_op_remap_config;
// - the batch-op rewriter, carrying the Base64-escaped serialized
//   batch_op_rewriter_config.
118 TEST(UpdateRewriterConfigTest, AddOptimizers) {
119 const std::string export_dir =
120 test_util::TestSrcDirPath(kTestSavedModelWithSavedModelConfigPath);
121 absl::StatusOr<SavedModelConfig> saved_model_config =
122 LoadSavedModelConfigOrDefault(export_dir);
124 TF_ASSERT_OK(saved_model_config.status());
// Default-constructed: starts with no custom optimizers.
125 tensorflow::RewriterConfig rewrite_options;
// NOTE(review): the remaining argument(s) of this call are elided from this
// listing; presumably `&rewrite_options` is passed as the output. Confirm.
127 UpdateRewriterConfig(saved_model_config.value().session_overrides(),
// UnorderedElementsAre matches the whole container in any order, so exactly
// these two optimizers must be present. Each expected proto is produced by
// absl::Substitute from the optimizer name constant, its config parameter
// key, and the Base64-escaped serialized sub-config. The Substitute template
// text is elided from this listing.
130 EXPECT_THAT(rewrite_options.custom_optimizers(),
131 ::testing::UnorderedElementsAre(
132 EqualsProto(absl::Substitute(
139 kRemoteOpConfigRewriter, kRemoteOpRewriteConfigParamKey,
140 absl::Base64Escape(saved_model_config.value()
142 .remote_op_remap_config()
143 .SerializeAsString()))),
144 EqualsProto(absl::Substitute(
151 kBatchOpRewriter, kBatchOpRewriteConfigParamKey,
152 absl::Base64Escape(saved_model_config.value()
154 .batch_op_rewriter_config()
155 .SerializeAsString())))));
// When the RewriterConfig already contains custom optimizers with placeholder
// parameter values, applying the SavedModelConfig session overrides must
// leave exactly the two expected optimizers, each holding the real serialized
// config.
158 TEST(UpdateRewriterConfigTest, ReplaceOptimizers) {
159 const std::string export_dir =
160 test_util::TestSrcDirPath(kTestSavedModelWithSavedModelConfigPath);
161 absl::StatusOr<SavedModelConfig> saved_model_config =
162 LoadSavedModelConfigOrDefault(export_dir);
164 TF_ASSERT_OK(saved_model_config.status());
165 tensorflow::RewriterConfig rewrite_options;
// Seed rewrite_options with two custom optimizers, "remote_op_config_rewrite"
// and "batch_op_rewriter", whose config parameters hold a placeholder string.
// NOTE(review): these names presumably equal kRemoteOpConfigRewriter /
// kBatchOpRewriter; confirm. The raw-string delimiters, the output argument
// and any check of `result` are elided from this listing.
166 bool result = ::google::protobuf::TextFormat::ParseFromString(
169 name: "remote_op_config_rewrite"
171 key: "remote_op_rewrite_config"
172 value { s: "whatever placeholder value" }
176 name: "batch_op_rewriter"
178 key: "batch_op_rewrite_config"
179 value { s: "whatever placeholder value" }
// NOTE(review): the remaining argument(s) of this call are elided from this
// listing; presumably `&rewrite_options`. Confirm.
185 UpdateRewriterConfig(saved_model_config.value().session_overrides(),
// With only two matchers, UnorderedElementsAre also requires the container to
// hold exactly two entries. This is what shows the seeded placeholders were
// replaced rather than kept alongside the new entries.
189 EXPECT_THAT(rewrite_options.custom_optimizers(),
190 ::testing::UnorderedElementsAre(
191 EqualsProto(absl::Substitute(
198 kRemoteOpConfigRewriter, kRemoteOpRewriteConfigParamKey,
199 absl::Base64Escape(saved_model_config.value()
201 .remote_op_remap_config()
202 .SerializeAsString()))),
203 EqualsProto(absl::Substitute(
210 kBatchOpRewriter, kBatchOpRewriteConfigParamKey,
211 absl::Base64Escape(saved_model_config.value()
213 .batch_op_rewriter_config()
214 .SerializeAsString())))));