16 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_factory.h"
25 #include "google/protobuf/wrappers.pb.h"
26 #include <gmock/gmock.h>
27 #include <gtest/gtest.h>
28 #include "tensorflow/cc/saved_model/constants.h"
29 #include "tensorflow/cc/saved_model/loader.h"
30 #include "tensorflow/cc/saved_model/tag_constants.h"
31 #include "tensorflow/core/framework/tensor_testutil.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/io/path.h"
35 #include "tensorflow/core/protobuf/named_tensor.pb.h"
36 #include "tensorflow/core/public/session.h"
37 #include "tensorflow/core/public/version.h"
38 #include "tensorflow_serving/core/test_util/session_test_util.h"
39 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test.h"
40 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
41 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
43 namespace tensorflow {
// Selects how CreateBundleFromPath builds the bundle.
enum class CreationType {
  kWithoutMetadata,  // factory->CreateSavedModelBundle(...)
  kWithMetadata,     // factory->CreateSavedModelBundleWithMetadata(...)
};
// Kind of test export the fixture points at: a TF SavedModel or a TFLite model.
enum class ModelType {
  kTfModel,
  kTfLiteModel,
};
51 Loader::Metadata CreateMetadata() {
return {ServableId{
"name", 42}}; }
// Builds a SavedModelBundle through SavedModelBundleFactory, with or without
// servable metadata depending on `creation_type`.
//
// A new-session hook is installed, and a copy of `config` has its session
// target set to test_util::kNewSessionHookSessionTargetPrefix. The hook checks
// that SessionOptions carries session metadata exactly when `creation_type` is
// kWithMetadata. When it does, the hook also checks that the metadata's
// name/version match CreateMetadata().
//
// NOTE(review): no line visible here declares the `path` parameter used below.
// No visible line assigns `factory` before it is dereferenced, and
// `config_with_session_hook` is never used on a visible line. Presumably lines
// not shown declare `path` and call
// SavedModelBundleFactory::Create(config_with_session_hook, &factory). Confirm.
54 Status CreateBundleFromPath(
const CreationType creation_type,
55 const SessionBundleConfig& config,
57 std::unique_ptr<SavedModelBundle>* bundle) {
58 std::unique_ptr<SavedModelBundleFactory> factory;
60 auto config_with_session_hook = config;
61 config_with_session_hook.set_session_target(
62 test_util::kNewSessionHookSessionTargetPrefix);
// NOTE(review): the hook captures `creation_type` by reference ([&]). If
// SetNewSessionHook keeps the callback after this function returns, a
// by-value capture would avoid a dangling reference. Confirm.
63 test_util::SetNewSessionHook([&](
const SessionOptions& session_options) {
64 const bool enable_session_metadata =
65 creation_type == CreationType::kWithMetadata;
66 EXPECT_EQ(enable_session_metadata,
67 session_options.config.experimental().has_session_metadata());
68 if (enable_session_metadata) {
69 const auto& actual_session_metadata =
70 session_options.config.experimental().session_metadata();
71 const auto& expected_loader_metadata = CreateMetadata();
72 EXPECT_EQ(expected_loader_metadata.servable_id.name,
73 actual_session_metadata.name());
74 EXPECT_EQ(expected_loader_metadata.servable_id.version,
75 actual_session_metadata.version());
77 return absl::OkStatus();
80 switch (creation_type) {
81 case CreationType::kWithoutMetadata:
82 TF_RETURN_IF_ERROR(factory->CreateSavedModelBundle(path, bundle));
// NOTE(review): no `break;` is visible before the next case. Without one,
// kWithoutMetadata would fall through and build the bundle a second time with
// metadata. Confirm a break exists on a line not shown.
84 case CreationType::kWithMetadata:
85 TF_RETURN_IF_ERROR(factory->CreateSavedModelBundleWithMetadata(
86 CreateMetadata(), path, bundle));
89 return absl::OkStatus();
// One parameter set for SavedModelBundleFactoryTest.
//   creation_type: forwarded to CreateBundleFromPath to choose plain vs.
//     metadata-enabled bundle creation.
//   prefer_tflite_model: copied into SessionBundleConfig by the fixture's
//     GetSessionBundleConfig().
// NOTE(review): the fixture also reads GetParam().model_type, but that member
// is not on any line visible here. Confirm it is declared.
92 struct SavedModelBundleFactoryTestParam {
93 CreationType creation_type;
95 bool prefer_tflite_model;
// Parameterized fixture that runs the shared test_util::BundleFactoryTest
// checks against SavedModelBundleFactory, across the parameter combinations in
// SavedModelBundleFactoryTestParam.
99 class SavedModelBundleFactoryTest
100 :
public test_util::BundleFactoryTest,
101 public ::testing::WithParamInterface<SavedModelBundleFactoryTestParam> {
// Passes the TF SavedModel test path to the base class for ModelType::kTfModel,
// and the TFLite test model path otherwise.
103 SavedModelBundleFactoryTest()
104 : test_util::BundleFactoryTest(
105 GetParam().model_type == ModelType::kTfModel
106 ? test_util::GetTestSavedModelPath()
107 : test_util::GetTestTfLiteModelPath()) {}
109 virtual ~SavedModelBundleFactoryTest() =
default;
// Builds a bundle from export_dir_ and moves its session out to the caller.
// The rest of the bundle is released when `bundle` goes out of scope.
112 Status CreateSession(
const SessionBundleConfig& config,
113 std::unique_ptr<Session>* session)
const override {
114 std::unique_ptr<SavedModelBundle> bundle;
115 TF_RETURN_IF_ERROR(CreateBundleFromPath(GetParam().creation_type, config,
116 export_dir_, &bundle));
117 *session = std::move(bundle->session);
118 return absl::OkStatus();
// Base config for every test: prefer_tflite_model mirrors the test parameter.
// NOTE(review): the line returning `config` is not visible here. Confirm.
121 SessionBundleConfig GetSessionBundleConfig()
const override {
122 SessionBundleConfig config;
123 config.set_prefer_tflite_model(GetParam().prefer_tflite_model);
// RunOptions are unsupported only for a TFLite model loaded with
// prefer_tflite_model set.
127 bool IsRunOptionsSupported()
const override {
129 return GetParam().prefer_tflite_model ==
false ||
130 GetParam().model_type != ModelType::kTfLiteModel;
// Bundle creation is expected to fail when the export is a TFLite model but
// prefer_tflite_model is false.
133 bool ExpectCreateBundleFailure()
const override {
135 return GetParam().prefer_tflite_model ==
false &&
136 GetParam().model_type == ModelType::kTfLiteModel;
// Paths of the on-disk files making up the test export for the current
// model_type. Their total size is the expected resource estimate in
// EstimateResourceRequirementWithGoodExport.
139 std::vector<string> GetModelFiles() {
140 switch (GetParam().model_type) {
141 case ModelType::kTfModel: {
142 const string& dir = test_util::GetTestSavedModelPath();
144 io::JoinPath(dir, kSavedModelAssetsDirectory,
"foo.txt"),
145 io::JoinPath(dir, kSavedModelFilenamePb),
146 io::JoinPath(dir, kSavedModelVariablesFilename,
147 "variables.data-00000-of-00001"),
148 io::JoinPath(dir, kSavedModelVariablesFilename,
"variables.index")};
150 case ModelType::kTfLiteModel: {
152 io::JoinPath(test_util::GetTestTfLiteModelPath(),
"model.tflite")};
// Instantiates the suite with eight parameter sets. Each (CreationType,
// ModelType) pair appears twice. Presumably the two entries of each pair differ
// in prefer_tflite_model, which is set on lines not visible here. Confirm.
160 INSTANTIATE_TEST_SUITE_P(
161 CreationType, SavedModelBundleFactoryTest,
163 SavedModelBundleFactoryTestParam{
164 CreationType::kWithoutMetadata, ModelType::kTfModel,
167 SavedModelBundleFactoryTestParam{
168 CreationType::kWithoutMetadata, ModelType::kTfModel,
171 SavedModelBundleFactoryTestParam{
172 CreationType::kWithoutMetadata, ModelType::kTfLiteModel,
175 SavedModelBundleFactoryTestParam{
176 CreationType::kWithoutMetadata, ModelType::kTfLiteModel,
179 SavedModelBundleFactoryTestParam{
180 CreationType::kWithMetadata, ModelType::kTfModel,
183 SavedModelBundleFactoryTestParam{
184 CreationType::kWithMetadata, ModelType::kTfModel,
187 SavedModelBundleFactoryTestParam{
188 CreationType::kWithMetadata, ModelType::kTfLiteModel,
191 SavedModelBundleFactoryTestParam{
192 CreationType::kWithMetadata, ModelType::kTfLiteModel,
196 TEST_P(SavedModelBundleFactoryTest, Basic) { TestBasic(); }
// Covers the default case, where remove_unused_fields_from_bundle_metagraph is
// left unset, using the "serve" tag. Combinations flagged by
// ExpectCreateBundleFailure() are expected to fail to load. Otherwise the
// loaded MetaGraphDef keeps its signature_defs. It also keeps its graph_def,
// except when a TFLite model is loaded with prefer_tflite_model set; in that
// case graph_def is expected to be absent.
198 TEST_P(SavedModelBundleFactoryTest, RemoveUnusedFieldsFromMetaGraphDefault) {
199 SessionBundleConfig config = GetSessionBundleConfig();
200 *config.add_saved_model_tags() = kSavedModelTagServe;
201 std::unique_ptr<SavedModelBundle> bundle;
202 if (ExpectCreateBundleFailure()) {
203 EXPECT_FALSE(CreateBundleFromPath(GetParam().creation_type, config,
204 export_dir_, &bundle)
208 TF_ASSERT_OK(CreateBundleFromPath(GetParam().creation_type, config,
209 export_dir_, &bundle));
210 if (GetParam().prefer_tflite_model &&
211 (GetParam().model_type == ModelType::kTfLiteModel)) {
213 EXPECT_FALSE(bundle->meta_graph_def.has_graph_def());
215 EXPECT_TRUE(bundle->meta_graph_def.has_graph_def());
217 EXPECT_FALSE(bundle->meta_graph_def.signature_def().empty());
// With remove_unused_fields_from_bundle_metagraph set, the bundle's
// MetaGraphDef is expected to have no graph_def but to keep its signature_defs.
// Combinations flagged by ExpectCreateBundleFailure() are expected to fail to
// load.
220 TEST_P(SavedModelBundleFactoryTest, RemoveUnusedFieldsFromMetaGraphEnabled) {
221 SessionBundleConfig config = GetSessionBundleConfig();
222 *config.add_saved_model_tags() = kSavedModelTagServe;
223 config.set_remove_unused_fields_from_bundle_metagraph(
true);
224 std::unique_ptr<SavedModelBundle> bundle;
225 if (ExpectCreateBundleFailure()) {
226 EXPECT_FALSE(CreateBundleFromPath(GetParam().creation_type, config,
227 export_dir_, &bundle)
231 TF_ASSERT_OK(CreateBundleFromPath(GetParam().creation_type, config,
232 export_dir_, &bundle));
233 EXPECT_FALSE(bundle->meta_graph_def.has_graph_def());
234 EXPECT_FALSE(bundle->meta_graph_def.signature_def().empty());
// Runs the shared batching test twice with max_batch_size 4:
//   1. with large-batch splitting disabled;
//   2. with splitting enabled and max_execution_batch_size 2.
// The remaining TestBatching arguments are on lines not visible here.
237 TEST_P(SavedModelBundleFactoryTest, Batching) {
241 TestBatching(test_util::CreateProto<BatchingParameters>(R
"(
242 max_batch_size { value: 4 }
243 enable_large_batch_splitting { value: False })"),
248 TestBatching(test_util::CreateProto<BatchingParameters>(R
"(
249 max_batch_size { value: 4 }
250 enable_large_batch_splitting { value: True }
251 max_execution_batch_size { value: 2 })"),
// Copies the test export to a temp directory and writes a batching_params.pbtxt
// into the copy (max_batch_size 10, batch_timeout_micros 100000000). It then
// points export_dir_ at the copy and runs TestBatching with common parameters
// (max_batch_size 4), once with the boolean argument false and once with true.
// Presumably the boolean controls whether the per-model file is honored;
// confirm against TestBatching.
257 TEST_P(SavedModelBundleFactoryTest, PerModelBatchingParams) {
261 const string dst_dir = io::JoinPath(testing::TmpDir(),
"model");
262 test_util::CopyDirOrDie(export_dir_, dst_dir);
264 const string& per_model_params_pbtxt(R
"(
265 max_batch_size { value: 10 }
266 batch_timeout_micros { value: 100000000 })");
267 std::ofstream ofs(io::JoinPath(dst_dir, "batching_params.pbtxt"));
268 ofs << per_model_params_pbtxt;
// NOTE(review): `ofs` is still in scope when TestBatching runs below. Unless a
// line not shown calls ofs.close() or flushes it, the file contents may not be
// on disk yet when the bundle is loaded. Confirm.
270 export_dir_ = dst_dir;
272 const BatchingParameters& common_params =
273 test_util::CreateProto<BatchingParameters>(
274 R
"(max_batch_size { value: 4 })");
275 TestBatching(common_params, false,
277 TestBatching(common_params,
true,
// Checks SavedModelBundleFactory's resource estimate against kTotalFileSize,
// the summed size of the export files listed by GetModelFiles(). The remaining
// call arguments are on lines not visible here.
281 TEST_P(SavedModelBundleFactoryTest, EstimateResourceRequirementWithGoodExport) {
282 const double kTotalFileSize = test_util::GetTotalFileSize(GetModelFiles());
283 TestEstimateResourceRequirementWithGoodExport<SavedModelBundleFactory>(
287 TEST_P(SavedModelBundleFactoryTest, RunOptions) { TestRunOptions(); }
289 TEST_P(SavedModelBundleFactoryTest, RunOptionsError) { TestRunOptionsError(); }
// NOTE(review): stray declaration matching SavedModelBundleFactory::Create's
// signature. It has no terminating ';' and no body, so as written it does not
// compile. It looks like a paste artifact; confirm and remove.
static Status Create(const SessionBundleConfig &config, std::unique_ptr< SavedModelBundleFactory > *factory)