TensorFlow Serving C++ API Documentation
saved_model_config_util_test.cc
1 /* Copyright 2023 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
17 
18 #include <string>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "absl/status/statusor.h"
23 #include "absl/strings/escaping.h"
24 #include "absl/strings/substitute.h"
25 #include "google/protobuf/text_format.h"
26 #include "xla/tsl/lib/core/status_test_util.h"
27 #include "tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.pb.h"
28 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
29 #include "tensorflow_serving/servables/tensorflow/remote_op_config_rewriter.pb.h"
30 #include "tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
31 #include "tensorflow_serving/test_util/test_util.h"
32 
33 namespace tensorflow {
34 namespace serving {
35 namespace {
36 
// SavedModel export fixtures used by the tests below. Each path is resolved
// with test_util::TestSrcDirPath() before being loaded.

// Export directory that contains no SavedModelConfig.
constexpr char kTestSavedModelWithoutSavedModelConfigPath[] =
    "servables/tensorflow/testdata/"
    "saved_model_half_plus_two_cpu/00000123";

// Export directory that contains a populated SavedModelConfig.
constexpr char kTestSavedModelWithSavedModelConfigPath[] =
    "servables/tensorflow/testdata/"
    "saved_model_half_plus_two_cpu_with_saved_model_config/00000123";

// Export directory whose SavedModelConfig is empty.
constexpr char kTestSavedModelWithEmptySavedModelConfigPath[] =
    "servables/tensorflow/testdata/"
    "saved_model_half_plus_two_cpu_with_empty_saved_model_config/00000123";
48 
// gMock matcher comparing a proto against a text-format string or another
// proto; used by every expectation in this file.
using test_util::EqualsProto;
50 
51 TEST(LoadSavedModeConfigTest, MissingSavedModelConfig) {
52  const std::string export_dir =
53  test_util::TestSrcDirPath(kTestSavedModelWithoutSavedModelConfigPath);
54 
55  absl::StatusOr<SavedModelConfig> saved_model_config =
56  LoadSavedModelConfigOrDefault(export_dir);
57  TF_ASSERT_OK(saved_model_config.status());
58  EXPECT_THAT(saved_model_config.value(), EqualsProto(""));
59 }
60 
61 TEST(LoadSavedModelConfigTest, EmptySavedModelConfig) {
62  const std::string export_dir =
63  test_util::TestSrcDirPath(kTestSavedModelWithEmptySavedModelConfigPath);
64 
65  absl::StatusOr<SavedModelConfig> saved_model_config =
66  LoadSavedModelConfigOrDefault(export_dir);
67 
68  TF_ASSERT_OK(saved_model_config.status());
69  EXPECT_THAT(saved_model_config.value(), EqualsProto(""));
70 }
71 
72 TEST(LoadSavedModelConfigTest, SavedModelConfig) {
73  const std::string export_dir =
74  test_util::TestSrcDirPath(kTestSavedModelWithSavedModelConfigPath);
75  absl::StatusOr<SavedModelConfig> saved_model_config =
76  LoadSavedModelConfigOrDefault(export_dir);
77 
78  TF_ASSERT_OK(saved_model_config.status());
79 
80  SavedModelConfig expected_config;
81  bool result = ::google::protobuf::TextFormat::ParseFromString(
82  R"pb(
83  session_overrides {
84  remote_op_remap_config {
85  model_name_remap {
86  key: "placeholder_model_name"
87  value: "model_name"
88  }
89  target_address_remap {
90  key: "placeholder_model_name"
91  value: "target_address"
92  }
93  }
94  batch_op_rewriter_config {
95  batch_options {
96  key: "placeholder_model_name"
97  value: {
98  batch_timeout_micros: 100
99  allowed_batch_sizes: [ 2, 4, 8 ]
100  }
101  }
102  }
103  }
104  tfrt_runtime_config {
105  config {
106  type_url: "type.googleapis.com/tensorflow.tfrt_stub.TestConfig1"
107  value: "\n\rtest config 1"
108  }
109  }
110  critical: true
111  )pb",
112  &expected_config);
113 
114  EXPECT_TRUE(result);
115  EXPECT_THAT(saved_model_config.value(), EqualsProto(expected_config));
116 }
117 
118 TEST(UpdateRewriterConfigTest, AddOptimizers) {
119  const std::string export_dir =
120  test_util::TestSrcDirPath(kTestSavedModelWithSavedModelConfigPath);
121  absl::StatusOr<SavedModelConfig> saved_model_config =
122  LoadSavedModelConfigOrDefault(export_dir);
123 
124  TF_ASSERT_OK(saved_model_config.status());
125  tensorflow::RewriterConfig rewrite_options;
126 
127  UpdateRewriterConfig(saved_model_config.value().session_overrides(),
128  &rewrite_options);
129 
130  EXPECT_THAT(rewrite_options.custom_optimizers(),
131  ::testing::UnorderedElementsAre(
132  EqualsProto(absl::Substitute(
133  R"pb(
134  name: "$0"
135  parameter_map {
136  key: "$1"
137  value { s: "$2" }
138  })pb",
139  kRemoteOpConfigRewriter, kRemoteOpRewriteConfigParamKey,
140  absl::Base64Escape(saved_model_config.value()
141  .session_overrides()
142  .remote_op_remap_config()
143  .SerializeAsString()))),
144  EqualsProto(absl::Substitute(
145  R"pb(
146  name: "$0"
147  parameter_map {
148  key: "$1"
149  value { s: "$2" }
150  })pb",
151  kBatchOpRewriter, kBatchOpRewriteConfigParamKey,
152  absl::Base64Escape(saved_model_config.value()
153  .session_overrides()
154  .batch_op_rewriter_config()
155  .SerializeAsString())))));
156 }
157 
158 TEST(UpdateRewriterConfigTest, ReplaceOptimizers) {
159  const std::string export_dir =
160  test_util::TestSrcDirPath(kTestSavedModelWithSavedModelConfigPath);
161  absl::StatusOr<SavedModelConfig> saved_model_config =
162  LoadSavedModelConfigOrDefault(export_dir);
163 
164  TF_ASSERT_OK(saved_model_config.status());
165  tensorflow::RewriterConfig rewrite_options;
166  bool result = ::google::protobuf::TextFormat::ParseFromString(
167  R"pb(
168  custom_optimizers {
169  name: "remote_op_config_rewrite"
170  parameter_map {
171  key: "remote_op_rewrite_config"
172  value { s: "whatever placeholder value" }
173  }
174  }
175  custom_optimizers {
176  name: "batch_op_rewriter"
177  parameter_map {
178  key: "batch_op_rewrite_config"
179  value { s: "whatever placeholder value" }
180  }
181  }
182  )pb",
183  &rewrite_options);
184 
185  UpdateRewriterConfig(saved_model_config.value().session_overrides(),
186  &rewrite_options);
187 
188  EXPECT_TRUE(result);
189  EXPECT_THAT(rewrite_options.custom_optimizers(),
190  ::testing::UnorderedElementsAre(
191  EqualsProto(absl::Substitute(
192  R"pb(
193  name: "$0"
194  parameter_map {
195  key: "$1"
196  value { s: "$2" }
197  })pb",
198  kRemoteOpConfigRewriter, kRemoteOpRewriteConfigParamKey,
199  absl::Base64Escape(saved_model_config.value()
200  .session_overrides()
201  .remote_op_remap_config()
202  .SerializeAsString()))),
203  EqualsProto(absl::Substitute(
204  R"pb(
205  name: "$0"
206  parameter_map {
207  key: "$1"
208  value { s: "$2" }
209  })pb",
210  kBatchOpRewriter, kBatchOpRewriteConfigParamKey,
211  absl::Base64Escape(saved_model_config.value()
212  .session_overrides()
213  .batch_op_rewriter_config()
214  .SerializeAsString())))));
215 }
216 
217 } // namespace
218 } // namespace serving
219 } // namespace tensorflow