15 #include "tensorflow_serving/servables/tensorflow/saved_model_config.h"
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "absl/strings/escaping.h"
22 #include "absl/strings/substitute.h"
23 #include "tensorflow/cc/saved_model/constants.h"
24 #include "xla/tsl/lib/core/status_test_util.h"
25 #include "tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.pb.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
28 #include "tensorflow/core/tfrt/graph_executor/config.h"
29 #include "tensorflow/core/tfrt/graph_executor/test_config.pb.h"
30 #include "tsl/platform/env.h"
31 #include "tsl/platform/path.h"
32 #include "tsl/platform/status.h"
33 #include "tensorflow_serving/servables/tensorflow/remote_op_config_rewriter.pb.h"
34 #include "tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
35 #include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
36 #include "tensorflow_serving/test_util/test_util.h"
38 namespace tensorflow {
// SavedModel fixtures used by the tests below. The paths are relative to the
// test source directory and are resolved with test_util::TestSrcDirPath.

// Fixture with no SavedModelConfig asset.
constexpr char kTestSavedModelWithoutSavedModelConfigPath[] =
    "servables/tensorflow/testdata/saved_model_half_plus_two_cpu/00000123";

// Fixture that ships a SavedModelConfig (per the testdata directory name).
constexpr char kTestSavedModelWithModelConfigPath[] =
    "servables/tensorflow/testdata/"
    "saved_model_half_plus_two_cpu_with_saved_model_config/00000123";

// Fixture whose SavedModelConfig is empty (per the testdata directory name).
constexpr char kTestSavedModelWithEmptyModelConfigPath[] =
    "servables/tensorflow/testdata/"
    "saved_model_half_plus_two_cpu_with_empty_saved_model_config/00000123";
// Proto-equality matcher; the tests below use it with text-format proto
// expectations inside EXPECT_THAT / UnorderedElementsAre.
53 using test_util::EqualsProto;
55 TEST(SavedModeConfigTest, MissingSavedModelConfig) {
56 const std::string export_dir =
57 test_util::TestSrcDirPath(kTestSavedModelWithoutSavedModelConfigPath);
58 tensorflow::GraphOptions graph_options;
59 tensorflow::tfrt_stub::RuntimeConfig runtime_config;
61 TF_ASSERT_OK(LoadSavedModelConfig(export_dir, graph_options, runtime_config));
63 auto& custom_optimizers = graph_options.rewrite_options().custom_optimizers();
64 EXPECT_EQ(custom_optimizers.size(), 0);
65 EXPECT_EQ(runtime_config.ToProto().config_size(), 0);
68 TEST(ModelRuntimeConfigTest, EmptyModelConfig) {
69 const std::string export_dir =
70 test_util::TestSrcDirPath(kTestSavedModelWithEmptyModelConfigPath);
71 tensorflow::GraphOptions graph_options;
72 tensorflow::tfrt_stub::RuntimeConfig runtime_config;
74 TF_ASSERT_OK(LoadSavedModelConfig(export_dir, graph_options, runtime_config));
76 auto& custom_optimizers = graph_options.rewrite_options().custom_optimizers();
77 EXPECT_EQ(custom_optimizers.size(), 0);
78 EXPECT_EQ(runtime_config.ToProto().config_size(), 0);
// Checks how LoadSavedModelConfig treats runtime-config entries that are
// already present. Before loading, runtime_config is seeded with a TestConfig1
// tagged "whatever tag". After loading, its proto is expected to carry the
// TestConfig1 payload from the model's config instead; per the test name, the
// seeded entry is overwritten.
81 TEST(ModelRuntimeConfigTest, OverwriteRuntimeConfig) {
82 const std::string export_dir =
83 test_util::TestSrcDirPath(kTestSavedModelWithModelConfigPath);
84 tensorflow::GraphOptions graph_options;
85 tensorflow::tfrt_stub::RuntimeConfig runtime_config;
// Seed a pre-existing TestConfig1 that the loaded config should replace.
87 tensorflow::tfrt_stub::TestConfig1 old_test_config1;
88 old_test_config1.set_tag(
"whatever tag");
89 TF_ASSERT_OK(runtime_config.Add(old_test_config1));
91 TF_ASSERT_OK(LoadSavedModelConfig(export_dir, graph_options, runtime_config));
// Loading this fixture is also expected to add two custom graph optimizers.
93 auto& custom_optimizers = graph_options.rewrite_options().custom_optimizers();
94 EXPECT_EQ(custom_optimizers.size(), 2);
// `value` below is presumably the serialized TestConfig1 with tag
// "test config 1" ('\n' = field 1 length-delimited, '\r' = length 13), not the
// seeded "whatever tag". TODO confirm against test_config.proto.
96 runtime_config.ToProto(), EqualsProto(R
"pb(
98 type_url: "type.googleapis.com/tensorflow.tfrt_stub.TestConfig1"
99 value: "\n\rtest config 1"
// Cross-checks LoadSavedModelConfig against the raw SavedModelConfig asset.
// The asset file is first parsed directly (binary proto) to get the expected
// rewriter settings. Loading the same SavedModel must then yield exactly two
// custom optimizers, in any order:
//   - the remote-op config rewriter, parameterized by the serialized,
//     Base64-escaped remote_op_remap_config;
//   - the batch-op rewriter, parameterized by the serialized, Base64-escaped
//     batch_op_rewriter_config.
// Both sub-messages come from session_overrides.
104 TEST(ModelRuntimeConfigTest, ModelConfig) {
105 const std::string export_dir =
106 test_util::TestSrcDirPath(kTestSavedModelWithModelConfigPath);
107 SavedModelConfig model_config;
// Read the config file at
// kSavedModelAssetsExtraDirectory/kSavedModelConfigPath inside the fixture.
110 TF_ASSERT_OK(tsl::ReadFileToString(
112 test_util::TestSrcDirPath(tsl::io::JoinPath(
113 kTestSavedModelWithModelConfigPath, kSavedModelAssetsExtraDirectory,
114 kSavedModelConfigPath)),
117 EXPECT_TRUE(model_config.ParseFromString(content));
// Now exercise the code under test on the same SavedModel.
120 tensorflow::GraphOptions graph_options;
121 tensorflow::tfrt_stub::RuntimeConfig runtime_config;
123 TF_ASSERT_OK(LoadSavedModelConfig(export_dir, graph_options, runtime_config));
125 auto& custom_optimizers = graph_options.rewrite_options().custom_optimizers();
126 EXPECT_EQ(custom_optimizers.size(), 2);
// UnorderedElementsAre: the order in which the two optimizers are registered
// is not asserted.
128 EXPECT_THAT(custom_optimizers,
129 ::testing::UnorderedElementsAre(
// Expected remote-op rewriter entry.
130 EqualsProto(absl::Substitute(
137 kRemoteOpConfigRewriter, kRemoteOpRewriteConfigParamKey,
138 absl::Base64Escape(model_config.session_overrides()
139 .remote_op_remap_config()
140 .SerializeAsString()))),
// Expected batch-op rewriter entry.
141 EqualsProto(absl::Substitute(
148 kBatchOpRewriter, kBatchOpRewriteConfigParamKey,
149 absl::Base64Escape(model_config.session_overrides()
150 .batch_op_rewriter_config()
151 .SerializeAsString())))));