TensorFlow Serving C++ API Documentation
saved_model_config_test.cc
1 /* Copyright 2023 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow_serving/servables/tensorflow/saved_model_config.h"
16 
17 #include <string>
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "absl/strings/escaping.h"
22 #include "absl/strings/substitute.h"
23 #include "tensorflow/cc/saved_model/constants.h"
24 #include "xla/tsl/lib/core/status_test_util.h"
25 #include "tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.pb.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
28 #include "tensorflow/core/tfrt/graph_executor/config.h"
29 #include "tensorflow/core/tfrt/graph_executor/test_config.pb.h"
30 #include "tsl/platform/env.h"
31 #include "tsl/platform/path.h"
32 #include "tsl/platform/status.h"
33 #include "tensorflow_serving/servables/tensorflow/remote_op_config_rewriter.pb.h"
34 #include "tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
35 #include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
36 #include "tensorflow_serving/test_util/test_util.h"
37 
38 namespace tensorflow {
39 namespace serving {
40 namespace {
41 const char kTestSavedModelWithoutSavedModelConfigPath[] =
42  "servables/tensorflow/testdata/"
43  "saved_model_half_plus_two_cpu/00000123";
44 
45 const char kTestSavedModelWithModelConfigPath[] =
46  "servables/tensorflow/testdata/"
47  "saved_model_half_plus_two_cpu_with_saved_model_config/00000123";
48 
49 const char kTestSavedModelWithEmptyModelConfigPath[] =
50  "servables/tensorflow/testdata/"
51  "saved_model_half_plus_two_cpu_with_empty_saved_model_config/00000123";
52 
53 using test_util::EqualsProto;
54 
55 TEST(SavedModeConfigTest, MissingSavedModelConfig) {
56  const std::string export_dir =
57  test_util::TestSrcDirPath(kTestSavedModelWithoutSavedModelConfigPath);
58  tensorflow::GraphOptions graph_options;
59  tensorflow::tfrt_stub::RuntimeConfig runtime_config;
60 
61  TF_ASSERT_OK(LoadSavedModelConfig(export_dir, graph_options, runtime_config));
62 
63  auto& custom_optimizers = graph_options.rewrite_options().custom_optimizers();
64  EXPECT_EQ(custom_optimizers.size(), 0);
65  EXPECT_EQ(runtime_config.ToProto().config_size(), 0);
66 }
67 
68 TEST(ModelRuntimeConfigTest, EmptyModelConfig) {
69  const std::string export_dir =
70  test_util::TestSrcDirPath(kTestSavedModelWithEmptyModelConfigPath);
71  tensorflow::GraphOptions graph_options;
72  tensorflow::tfrt_stub::RuntimeConfig runtime_config;
73 
74  TF_ASSERT_OK(LoadSavedModelConfig(export_dir, graph_options, runtime_config));
75 
76  auto& custom_optimizers = graph_options.rewrite_options().custom_optimizers();
77  EXPECT_EQ(custom_optimizers.size(), 0);
78  EXPECT_EQ(runtime_config.ToProto().config_size(), 0);
79 }
80 
81 TEST(ModelRuntimeConfigTest, OverwriteRuntimeConfig) {
82  const std::string export_dir =
83  test_util::TestSrcDirPath(kTestSavedModelWithModelConfigPath);
84  tensorflow::GraphOptions graph_options;
85  tensorflow::tfrt_stub::RuntimeConfig runtime_config;
86 
87  tensorflow::tfrt_stub::TestConfig1 old_test_config1;
88  old_test_config1.set_tag("whatever tag");
89  TF_ASSERT_OK(runtime_config.Add(old_test_config1));
90 
91  TF_ASSERT_OK(LoadSavedModelConfig(export_dir, graph_options, runtime_config));
92 
93  auto& custom_optimizers = graph_options.rewrite_options().custom_optimizers();
94  EXPECT_EQ(custom_optimizers.size(), 2);
95  EXPECT_THAT(
96  runtime_config.ToProto(), EqualsProto(R"pb(
97  config {
98  type_url: "type.googleapis.com/tensorflow.tfrt_stub.TestConfig1"
99  value: "\n\rtest config 1"
100  }
101  )pb"));
102 }
103 
104 TEST(ModelRuntimeConfigTest, ModelConfig) {
105  const std::string export_dir =
106  test_util::TestSrcDirPath(kTestSavedModelWithModelConfigPath);
107  SavedModelConfig model_config;
108  {
109  std::string content;
110  TF_ASSERT_OK(tsl::ReadFileToString(
111  tsl::Env::Default(),
112  test_util::TestSrcDirPath(tsl::io::JoinPath(
113  kTestSavedModelWithModelConfigPath, kSavedModelAssetsExtraDirectory,
114  kSavedModelConfigPath)),
115  &content));
116 
117  EXPECT_TRUE(model_config.ParseFromString(content));
118  }
119 
120  tensorflow::GraphOptions graph_options;
121  tensorflow::tfrt_stub::RuntimeConfig runtime_config;
122 
123  TF_ASSERT_OK(LoadSavedModelConfig(export_dir, graph_options, runtime_config));
124 
125  auto& custom_optimizers = graph_options.rewrite_options().custom_optimizers();
126  EXPECT_EQ(custom_optimizers.size(), 2);
127 
128  EXPECT_THAT(custom_optimizers,
129  ::testing::UnorderedElementsAre(
130  EqualsProto(absl::Substitute(
131  R"pb(
132  name: "$0"
133  parameter_map {
134  key: "$1"
135  value { s: "$2" }
136  })pb",
137  kRemoteOpConfigRewriter, kRemoteOpRewriteConfigParamKey,
138  absl::Base64Escape(model_config.session_overrides()
139  .remote_op_remap_config()
140  .SerializeAsString()))),
141  EqualsProto(absl::Substitute(
142  R"pb(
143  name: "$0"
144  parameter_map {
145  key: "$1"
146  value { s: "$2" }
147  })pb",
148  kBatchOpRewriter, kBatchOpRewriteConfigParamKey,
149  absl::Base64Escape(model_config.session_overrides()
150  .batch_op_rewriter_config()
151  .SerializeAsString())))));
152 }
153 
154 } // namespace
155 } // namespace serving
156 } // namespace tensorflow