TensorFlow Serving C++ API Documentation
saved_model_bundle_factory_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_factory.h"
17 
18 #include <fstream>
19 #include <iostream>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "google/protobuf/wrappers.pb.h"
26 #include <gmock/gmock.h>
27 #include <gtest/gtest.h>
28 #include "tensorflow/cc/saved_model/constants.h"
29 #include "tensorflow/cc/saved_model/loader.h"
30 #include "tensorflow/cc/saved_model/tag_constants.h"
31 #include "tensorflow/core/framework/tensor_testutil.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/io/path.h"
35 #include "tensorflow/core/protobuf/named_tensor.pb.h"
36 #include "tensorflow/core/public/session.h"
37 #include "tensorflow/core/public/version.h"
38 #include "tensorflow_serving/core/test_util/session_test_util.h"
39 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test.h"
40 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
41 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
42 
43 namespace tensorflow {
44 namespace serving {
45 namespace {
46 
// Which SavedModelBundleFactory entry point a test case exercises.
enum class CreationType {
  kWithoutMetadata,  // CreateSavedModelBundle()
  kWithMetadata,     // CreateSavedModelBundleWithMetadata()
};

// Which on-disk test model the fixture is pointed at.
enum class ModelType {
  kTfModel,      // test_util::GetTestSavedModelPath()
  kTfLiteModel,  // test_util::GetTestTfLiteModelPath()
};
50 
51 Loader::Metadata CreateMetadata() { return {ServableId{"name", 42}}; }
52 
53 // Creates a new session based on the config and export path.
54 Status CreateBundleFromPath(const CreationType creation_type,
55  const SessionBundleConfig& config,
56  const string& path,
57  std::unique_ptr<SavedModelBundle>* bundle) {
58  std::unique_ptr<SavedModelBundleFactory> factory;
59  TF_RETURN_IF_ERROR(SavedModelBundleFactory::Create(config, &factory));
60  auto config_with_session_hook = config;
61  config_with_session_hook.set_session_target(
62  test_util::kNewSessionHookSessionTargetPrefix);
63  test_util::SetNewSessionHook([&](const SessionOptions& session_options) {
64  const bool enable_session_metadata =
65  creation_type == CreationType::kWithMetadata;
66  EXPECT_EQ(enable_session_metadata,
67  session_options.config.experimental().has_session_metadata());
68  if (enable_session_metadata) {
69  const auto& actual_session_metadata =
70  session_options.config.experimental().session_metadata();
71  const auto& expected_loader_metadata = CreateMetadata();
72  EXPECT_EQ(expected_loader_metadata.servable_id.name,
73  actual_session_metadata.name());
74  EXPECT_EQ(expected_loader_metadata.servable_id.version,
75  actual_session_metadata.version());
76  }
77  return absl::OkStatus();
78  });
79 
80  switch (creation_type) {
81  case CreationType::kWithoutMetadata:
82  TF_RETURN_IF_ERROR(factory->CreateSavedModelBundle(path, bundle));
83  break;
84  case CreationType::kWithMetadata:
85  TF_RETURN_IF_ERROR(factory->CreateSavedModelBundleWithMetadata(
86  CreateMetadata(), path, bundle));
87  break;
88  }
89  return absl::OkStatus();
90 }
91 
92 struct SavedModelBundleFactoryTestParam {
93  CreationType creation_type;
94  ModelType model_type;
95  bool prefer_tflite_model;
96 };
97 
98 // Tests SavedModelBundleFactory with native SavedModel.
99 class SavedModelBundleFactoryTest
100  : public test_util::BundleFactoryTest,
101  public ::testing::WithParamInterface<SavedModelBundleFactoryTestParam> {
102  public:
103  SavedModelBundleFactoryTest()
104  : test_util::BundleFactoryTest(
105  GetParam().model_type == ModelType::kTfModel
106  ? test_util::GetTestSavedModelPath()
107  : test_util::GetTestTfLiteModelPath()) {}
108 
109  virtual ~SavedModelBundleFactoryTest() = default;
110 
111  protected:
112  Status CreateSession(const SessionBundleConfig& config,
113  std::unique_ptr<Session>* session) const override {
114  std::unique_ptr<SavedModelBundle> bundle;
115  TF_RETURN_IF_ERROR(CreateBundleFromPath(GetParam().creation_type, config,
116  export_dir_, &bundle));
117  *session = std::move(bundle->session);
118  return absl::OkStatus();
119  }
120 
121  SessionBundleConfig GetSessionBundleConfig() const override {
122  SessionBundleConfig config;
123  config.set_prefer_tflite_model(GetParam().prefer_tflite_model);
124  return config;
125  }
126 
127  bool IsRunOptionsSupported() const override {
128  // Presently TensorFlow Lite sessions do NOT support RunOptions.
129  return GetParam().prefer_tflite_model == false ||
130  GetParam().model_type != ModelType::kTfLiteModel;
131  }
132 
133  bool ExpectCreateBundleFailure() const override {
134  // The test Tensorflow Lite model does not include saved_model artifacts
135  return GetParam().prefer_tflite_model == false &&
136  GetParam().model_type == ModelType::kTfLiteModel;
137  }
138 
139  std::vector<string> GetModelFiles() {
140  switch (GetParam().model_type) {
141  case ModelType::kTfModel: {
142  const string& dir = test_util::GetTestSavedModelPath();
143  return {
144  io::JoinPath(dir, kSavedModelAssetsDirectory, "foo.txt"),
145  io::JoinPath(dir, kSavedModelFilenamePb),
146  io::JoinPath(dir, kSavedModelVariablesFilename,
147  "variables.data-00000-of-00001"),
148  io::JoinPath(dir, kSavedModelVariablesFilename, "variables.index")};
149  }
150  case ModelType::kTfLiteModel: {
151  return {
152  io::JoinPath(test_util::GetTestTfLiteModelPath(), "model.tflite")};
153  }
154  default:
155  return {};
156  }
157  }
158 };
159 
160 INSTANTIATE_TEST_SUITE_P(
161  CreationType, SavedModelBundleFactoryTest,
162  ::testing::Values(
163  SavedModelBundleFactoryTestParam{
164  CreationType::kWithoutMetadata, ModelType::kTfModel,
165  true // prefer_tflite_model
166  },
167  SavedModelBundleFactoryTestParam{
168  CreationType::kWithoutMetadata, ModelType::kTfModel,
169  false // prefer_tflite_model
170  },
171  SavedModelBundleFactoryTestParam{
172  CreationType::kWithoutMetadata, ModelType::kTfLiteModel,
173  true // prefer_tflite_model
174  },
175  SavedModelBundleFactoryTestParam{
176  CreationType::kWithoutMetadata, ModelType::kTfLiteModel,
177  false // prefer_tflite_model
178  },
179  SavedModelBundleFactoryTestParam{
180  CreationType::kWithMetadata, ModelType::kTfModel,
181  true // prefer_tflite_model
182  },
183  SavedModelBundleFactoryTestParam{
184  CreationType::kWithMetadata, ModelType::kTfModel,
185  false // prefer_tflite_model
186  },
187  SavedModelBundleFactoryTestParam{
188  CreationType::kWithMetadata, ModelType::kTfLiteModel,
189  true // prefer_tflite_model
190  },
191  SavedModelBundleFactoryTestParam{
192  CreationType::kWithMetadata, ModelType::kTfLiteModel,
193  false // prefer_tflite_model
194  }));
195 
196 TEST_P(SavedModelBundleFactoryTest, Basic) { TestBasic(); }
197 
198 TEST_P(SavedModelBundleFactoryTest, RemoveUnusedFieldsFromMetaGraphDefault) {
199  SessionBundleConfig config = GetSessionBundleConfig();
200  *config.add_saved_model_tags() = kSavedModelTagServe;
201  std::unique_ptr<SavedModelBundle> bundle;
202  if (ExpectCreateBundleFailure()) {
203  EXPECT_FALSE(CreateBundleFromPath(GetParam().creation_type, config,
204  export_dir_, &bundle)
205  .ok());
206  return;
207  }
208  TF_ASSERT_OK(CreateBundleFromPath(GetParam().creation_type, config,
209  export_dir_, &bundle));
210  if (GetParam().prefer_tflite_model &&
211  (GetParam().model_type == ModelType::kTfLiteModel)) {
212  // TF Lite model never has a graph_def.
213  EXPECT_FALSE(bundle->meta_graph_def.has_graph_def());
214  } else {
215  EXPECT_TRUE(bundle->meta_graph_def.has_graph_def());
216  }
217  EXPECT_FALSE(bundle->meta_graph_def.signature_def().empty());
218 }
219 
220 TEST_P(SavedModelBundleFactoryTest, RemoveUnusedFieldsFromMetaGraphEnabled) {
221  SessionBundleConfig config = GetSessionBundleConfig();
222  *config.add_saved_model_tags() = kSavedModelTagServe;
223  config.set_remove_unused_fields_from_bundle_metagraph(true);
224  std::unique_ptr<SavedModelBundle> bundle;
225  if (ExpectCreateBundleFailure()) {
226  EXPECT_FALSE(CreateBundleFromPath(GetParam().creation_type, config,
227  export_dir_, &bundle)
228  .ok());
229  return;
230  }
231  TF_ASSERT_OK(CreateBundleFromPath(GetParam().creation_type, config,
232  export_dir_, &bundle));
233  EXPECT_FALSE(bundle->meta_graph_def.has_graph_def());
234  EXPECT_FALSE(bundle->meta_graph_def.signature_def().empty());
235 }
236 
237 TEST_P(SavedModelBundleFactoryTest, Batching) {
238  // Most test cases don't cover batching session code path so call
239  // 'TestBatching' twice with different options for batching test case, as
240  // opposed to parameterize test.
241  TestBatching(test_util::CreateProto<BatchingParameters>(R"(
242  max_batch_size { value: 4 }
243  enable_large_batch_splitting { value: False })"),
244  /*enable_per_model_batching_params=*/false,
245  /*input_request_batch_size=*/2,
246  /*batch_size=*/4);
247 
248  TestBatching(test_util::CreateProto<BatchingParameters>(R"(
249  max_batch_size { value: 4 }
250  enable_large_batch_splitting { value: True }
251  max_execution_batch_size { value: 2 })"),
252  /*enable_per_model_batching_params=*/false,
253  /*input_request_batch_size=*/3,
254  /*batch_size=*/2);
255 }
256 
257 TEST_P(SavedModelBundleFactoryTest, PerModelBatchingParams) {
258  //
259  // Copy SavedModel to temp (writable) location, and add batching params.
260  //
261  const string dst_dir = io::JoinPath(testing::TmpDir(), "model");
262  test_util::CopyDirOrDie(export_dir_, dst_dir);
263  // Note, timeout is set to high value to force batch formation.
264  const string& per_model_params_pbtxt(R"(
265  max_batch_size { value: 10 }
266  batch_timeout_micros { value: 100000000 })");
267  std::ofstream ofs(io::JoinPath(dst_dir, "batching_params.pbtxt"));
268  ofs << per_model_params_pbtxt;
269  ofs.close();
270  export_dir_ = dst_dir;
271 
272  const BatchingParameters& common_params =
273  test_util::CreateProto<BatchingParameters>(
274  R"(max_batch_size { value: 4 })");
275  TestBatching(common_params, /*enable_per_model_batching_params=*/false,
276  /*input_request_batch_size=*/2, /*batch_size=*/4);
277  TestBatching(common_params, /*enable_per_model_batching_params=*/true,
278  /*input_request_batch_size=*/2, /*batch_size=*/10);
279 }
280 
281 TEST_P(SavedModelBundleFactoryTest, EstimateResourceRequirementWithGoodExport) {
282  const double kTotalFileSize = test_util::GetTotalFileSize(GetModelFiles());
283  TestEstimateResourceRequirementWithGoodExport<SavedModelBundleFactory>(
284  kTotalFileSize);
285 }
286 
287 TEST_P(SavedModelBundleFactoryTest, RunOptions) { TestRunOptions(); }
288 
289 TEST_P(SavedModelBundleFactoryTest, RunOptionsError) { TestRunOptionsError(); }
290 
291 } // namespace
292 } // namespace serving
293 } // namespace tensorflow
static Status Create(const SessionBundleConfig &config, std::unique_ptr< SavedModelBundleFactory > *factory)