TensorFlow Serving C++ API Documentation
bundle_factory_util_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "google/protobuf/wrappers.pb.h"
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/io/path.h"
31 #include "tensorflow/core/platform/test_benchmark.h"
32 #include "tensorflow/core/protobuf/config.pb.h"
33 #include "tensorflow/core/public/session.h"
34 #include "tensorflow/core/public/session_options.h"
35 #include "tensorflow/core/public/version.h"
36 #include "tensorflow_serving/batching/batching_session.h"
37 #include "tensorflow_serving/resources/resources.pb.h"
38 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
39 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
40 #include "tensorflow_serving/session_bundle/session_bundle_util.h"
41 #include "tensorflow_serving/test_util/test_util.h"
42 #include "tensorflow_serving/util/test_util/mock_file_probing_env.h"
43 
44 namespace tensorflow {
45 namespace serving {
46 namespace {
47 
using test_util::EqualsProto;

// Shared batch scheduler type used by the batching-related tests below.
using Batcher = SharedBatchScheduler<BatchingSessionTask>;
51 
52 class MockSession : public Session {
53  public:
54  MOCK_METHOD(tensorflow::Status, Create, (const GraphDef& graph), (override));
55  MOCK_METHOD(tensorflow::Status, Extend, (const GraphDef& graph), (override));
56  MOCK_METHOD(tensorflow::Status, ListDevices,
57  (std::vector<DeviceAttributes> * response), (override));
58  MOCK_METHOD(tensorflow::Status, Close, (), (override));
59 
60  Status Run(const RunOptions& run_options,
61  const std::vector<std::pair<string, Tensor>>& inputs,
62  const std::vector<string>& output_tensor_names,
63  const std::vector<string>& target_node_names,
64  std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
65  // half plus two: output should be input / 2 + 2.
66  const auto& input = inputs[0].second.flat<float>();
67  Tensor output(DT_FLOAT, inputs[0].second.shape());
68  test::FillFn<float>(&output,
69  [&](int i) -> float { return input(i) / 2 + 2; });
70  outputs->push_back(output);
71  return absl::OkStatus();
72  }
73 
74  // Unused, but we need to provide a definition (virtual = 0).
75  Status Run(const std::vector<std::pair<std::string, Tensor>>&,
76  const std::vector<std::string>&, const std::vector<std::string>&,
77  std::vector<Tensor>* outputs) override {
78  return errors::Unimplemented(
79  "Run with threadpool is not supported for this session.");
80  }
81 
82  // NOTE: The default definition for Run(...) with threading options already
83  // returns errors::Unimplemented.
84 };
85 
86 class BundleFactoryUtilTest : public ::testing::Test {
87  protected:
88  BundleFactoryUtilTest() : export_dir_(test_util::GetTestSavedModelPath()) {}
89 
90  virtual ~BundleFactoryUtilTest() = default;
91 
92  // Test data path, to be initialized to point at an export of half-plus-two.
93  const string export_dir_;
94 };
95 
96 TEST_F(BundleFactoryUtilTest, GetSessionOptions) {
97  SessionBundleConfig bundle_config;
98 
99  constexpr char kTarget[] = "target";
100  bundle_config.set_session_target(kTarget);
101  ConfigProto *config_proto = bundle_config.mutable_session_config();
102  config_proto->set_allow_soft_placement(true);
103 
104  SessionOptions session_options = GetSessionOptions(bundle_config);
105  EXPECT_EQ(session_options.target, kTarget);
106  EXPECT_THAT(session_options.config, EqualsProto(*config_proto));
107 }
108 
109 TEST_F(BundleFactoryUtilTest, GetRunOptions) {
110  SessionBundleConfig bundle_config;
111 
112  // Set the threadpool index to use for session-run calls to 1.
113  bundle_config.mutable_session_run_load_threadpool_index()->set_value(1);
114 
115  RunOptions want;
116  want.set_inter_op_thread_pool(1);
117  EXPECT_THAT(GetRunOptions(bundle_config), EqualsProto(want));
118 }
119 
120 TEST_F(BundleFactoryUtilTest, WrapSession) {
121  SavedModelBundle bundle;
122  TF_ASSERT_OK(LoadSavedModel(SessionOptions(), RunOptions(), export_dir_,
123  {"serve"}, &bundle));
124  TF_ASSERT_OK(WrapSession(&bundle.session));
125  test_util::TestSingleRequest(bundle.session.get());
126 }
127 
128 TEST_F(BundleFactoryUtilTest, WrapSessionIgnoreThreadPoolOptions) {
129  std::unique_ptr<Session> session(new MockSession);
130 
131  TF_ASSERT_OK(WrapSessionIgnoreThreadPoolOptions(&session));
132  test_util::TestSingleRequest(session.get());
133 }
134 
135 TEST_F(BundleFactoryUtilTest, WrapSessionForBatching) {
136  SavedModelBundle bundle;
137  TF_ASSERT_OK(LoadSavedModel(SessionOptions(), RunOptions(), export_dir_,
138  {"serve"}, &bundle));
139 
140  // Create BatchingParameters and batch scheduler.
141  BatchingParameters batching_params;
142  batching_params.mutable_max_batch_size()->set_value(2);
143  batching_params.mutable_max_enqueued_batches()->set_value(INT_MAX);
144 
145  std::shared_ptr<Batcher> batcher;
146  TF_ASSERT_OK(CreateBatchScheduler(batching_params, &batcher));
147 
148  // Wrap the session.
149  TF_ASSERT_OK(WrapSessionForBatching(batching_params, batcher,
150  {test_util::GetTestSessionSignature()},
151  &bundle.session));
152 
153  // Run multiple requests concurrently. They should be executed as 5 batches.
154  test_util::TestMultipleRequests(bundle.session.get(), 10, 2);
155 }
156 
157 TEST_F(BundleFactoryUtilTest, WrapSessionForBatchingConfigError) {
158  BatchingParameters batching_params;
159  batching_params.mutable_max_batch_size()->set_value(2);
160  // The last entry in 'allowed_batch_sizes' is supposed to equal
161  // 'max_batch_size'. Let's violate that constraint and ensure we get an error.
162  batching_params.add_allowed_batch_sizes(1);
163  batching_params.add_allowed_batch_sizes(3);
164 
165  std::shared_ptr<Batcher> batch_scheduler;
166  TF_ASSERT_OK(CreateBatchScheduler(batching_params, &batch_scheduler));
167 
168  SavedModelBundle bundle;
169  TF_ASSERT_OK(LoadSavedModel(SessionOptions(), RunOptions(), export_dir_,
170  {"serve"}, &bundle));
171  auto status = WrapSessionForBatching(batching_params, batch_scheduler,
172  {test_util::GetTestSessionSignature()},
173  &bundle.session);
174  ASSERT_TRUE(errors::IsInvalidArgument(status));
175 }
176 
177 TEST_F(BundleFactoryUtilTest, GetPerModelBatchingParams) {
178  const BatchingParameters common_params =
179  test_util::CreateProto<BatchingParameters>(R"(
180  allowed_batch_sizes: 8
181  allowed_batch_sizes: 16
182  max_batch_size { value: 16 })");
183 
184  const string per_model_params_pbtxt(R"(
185  allowed_batch_sizes: 8
186  allowed_batch_sizes: 16
187  allowed_batch_sizes: 128
188  max_batch_size { value: 128 })");
189 
190  std::unique_ptr<WritableFile> file;
191  TF_ASSERT_OK(Env::Default()->NewWritableFile(
192  io::JoinPath(testing::TmpDir(), "/batching_params.pbtxt"), &file));
193  TF_ASSERT_OK(file->Append(per_model_params_pbtxt));
194  TF_ASSERT_OK(file->Close());
195 
196  absl::optional<BatchingParameters> params;
197  TF_ASSERT_OK(GetPerModelBatchingParams("does/not/exists", common_params,
198  /*per_model_configured=*/false,
199  &params));
200  EXPECT_THAT(params.value(), test_util::EqualsProto(common_params));
201 
202  params.reset();
203  ASSERT_TRUE(GetPerModelBatchingParams("does/not/exists", common_params,
204  /*per_model_configured=*/true, &params)
205  .ok());
206 
207  params.reset();
208  TF_ASSERT_OK(GetPerModelBatchingParams(testing::TmpDir(), common_params,
209  /*per_model_configured=*/false,
210  &params));
211  EXPECT_THAT(params.value(), test_util::EqualsProto(common_params));
212 
213  params.reset();
214  TF_ASSERT_OK(GetPerModelBatchingParams(testing::TmpDir(), common_params,
215  /*per_model_configured=*/true,
216  &params));
217  EXPECT_THAT(params.value(), test_util::EqualsProto(per_model_params_pbtxt));
218 }
219 
220 TEST_F(BundleFactoryUtilTest, EstimateResourceFromPathWithBadExport) {
221  ResourceAllocation resource_requirement;
222  const Status status = EstimateResourceFromPath(
223  "/a/bogus/export/dir",
224  /*use_validation_result=*/false, &resource_requirement);
225  EXPECT_FALSE(status.ok());
226 }
227 
228 TEST_F(BundleFactoryUtilTest, EstimateResourceFromPathWithGoodExport) {
229  const double kTotalFileSize = test_util::GetTotalFileSize(
230  test_util::GetTestSavedModelBundleExportFiles());
231  ResourceAllocation expected =
232  test_util::GetExpectedResourceEstimate(kTotalFileSize);
233 
234  ResourceAllocation actual;
235  TF_ASSERT_OK(EstimateResourceFromPath(
236  export_dir_, /*use_validation_result=*/false, &actual));
237  EXPECT_THAT(actual, EqualsProto(expected));
238 }
239 
240 #ifdef PLATFORM_GOOGLE
241 // This benchmark relies on https://github.com/google/benchmark features,
242 // not available in open-sourced TF codebase.
243 
// Benchmarks a single Session::Run of the half-plus-two model while 1 to 64
// concurrent threads issue requests against a shared, wrapped session.
void BM_HalfPlusTwo(benchmark::State& state) {
  // Shared across all benchmark threads; intentionally never deleted since
  // the process exits after benchmarking.
  static Session* session;
  if (state.thread_index() == 0) {
    // NOTE(review): assumes the benchmark harness runs thread 0's setup
    // before other threads reach session->Run() — confirm the internal
    // framework provides that synchronization.
    SavedModelBundle bundle;
    TF_ASSERT_OK(LoadSavedModel(SessionOptions(), RunOptions(),
                                test_util::GetTestSavedModelPath(), {"serve"},
                                &bundle));
    TF_ASSERT_OK(WrapSession(&bundle.session));
    session = bundle.session.release();
  }
  Tensor input = test::AsTensor<float>({1.0, 2.0, 3.0}, TensorShape({3}));
  std::vector<Tensor> outputs;
  for (auto _ : state) {
    outputs.clear();
    // Feed "x:0", fetch "y:0"; the model computes input / 2 + 2.
    TF_ASSERT_OK(session->Run({{"x:0", input}}, {"y:0"}, {}, &outputs));
  }
}
BENCHMARK(BM_HalfPlusTwo)->UseRealTime()->ThreadRange(1, 64);
262 
263 #endif // PLATFORM_GOOGLE
264 
265 } // namespace
266 } // namespace serving
267 } // namespace tensorflow