TensorFlow Serving C++ API Documentation
bundle_factory_test.h
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_TEST_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_TEST_H_
18 
#include <climits>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/wrappers.pb.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow_serving/resources/resources.pb.h"
#include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/test_util/test_util.h"
35 
36 namespace tensorflow {
37 namespace serving {
38 namespace test_util {
39 
40 using test_util::EqualsProto;
41 
42 // The base class for SessionBundleFactoryTest and SavedModelBundleFactoryTest.
43 class BundleFactoryTest : public ::testing::Test {
44  public:
45  explicit BundleFactoryTest(const string &export_dir)
46  : export_dir_(export_dir) {}
47 
48  virtual ~BundleFactoryTest() = default;
49 
50  protected:
51  // Test functions to be used by subclasses.
52  void TestBasic() const {
53  const SessionBundleConfig config = GetSessionBundleConfig();
54  std::unique_ptr<Session> session;
55  if (ExpectCreateBundleFailure()) {
56  EXPECT_FALSE(CreateSession(config, &session).ok());
57  return;
58  }
59  TF_ASSERT_OK(CreateSession(config, &session));
60  TestSingleRequest(session.get());
61  }
62 
63  int GetTotalBatchesProcessed() const {
64  const string label(
65  "/tensorflow/serving/batching_session/wrapped_run_count");
66  auto* collection_registry = monitoring::CollectionRegistry::Default();
67  monitoring::CollectionRegistry::CollectMetricsOptions options;
68  const std::unique_ptr<monitoring::CollectedMetrics> collected_metrics =
69  collection_registry->CollectMetrics(options);
70  int total_count = 0;
71  const auto& point_set_map = collected_metrics->point_set_map;
72  if (point_set_map.find(label) == point_set_map.end()) return 0;
73  const monitoring::PointSet& lps = *point_set_map.at(label);
74  for (int i = 0; i < lps.points.size(); ++i) {
75  total_count += lps.points[i]->int64_value;
76  }
77  return static_cast<int>(total_count);
78  }
79 
80  void TestBatching(const BatchingParameters& params,
81  bool enable_per_model_batching_params,
82  int input_request_batch_size, int batch_size) const {
83  SessionBundleConfig config = GetSessionBundleConfig();
84  config.set_enable_per_model_batching_params(
85  enable_per_model_batching_params);
86  BatchingParameters* batching_params = config.mutable_batching_parameters();
87  *batching_params = params;
88 
89  //
90  // Tweak batching params further for testing.
91  //
92  // Set high value of max enqueued batches to prevent queue limits to be hit
93  // during testing that involves lot of requests.
94  batching_params->mutable_max_enqueued_batches()->set_value(INT_MAX);
95  // Set very high value of timeout to force full batches to be formed.
96  //
97  // The default (zero) value of the timeout causes batches to formed with
98  // [1..max_batch_size] size based on relative ordering of Run() calls. A
99  // large value causes deterministic fix batch size to be formed.
100  batching_params->mutable_batch_timeout_micros()->set_value(INT_MAX);
101 
102  std::unique_ptr<Session> session;
103  if (ExpectCreateBundleFailure()) {
104  EXPECT_FALSE(CreateSession(config, &session).ok());
105  return;
106  }
107  TF_ASSERT_OK(CreateSession(config, &session));
108 
109  const int num_requests = 10;
110  const int expected_batches =
111  (input_request_batch_size * num_requests) / batch_size;
112  const int orig_batches_processed = GetTotalBatchesProcessed();
113  TestMultipleRequests(session.get(), num_requests, input_request_batch_size);
114  EXPECT_EQ(orig_batches_processed + expected_batches,
115  GetTotalBatchesProcessed());
116  }
117 
118  template <class FactoryType>
119  void TestEstimateResourceRequirementWithGoodExport(
120  double total_file_size) const {
121  const SessionBundleConfig config = GetSessionBundleConfig();
122  std::unique_ptr<FactoryType> factory;
123  TF_ASSERT_OK(FactoryType::Create(config, &factory));
124  ResourceAllocation actual;
125  TF_ASSERT_OK(factory->EstimateResourceRequirement(export_dir_, &actual));
126 
127  ResourceAllocation expected = GetExpectedResourceEstimate(total_file_size);
128  EXPECT_THAT(actual, EqualsProto(expected));
129  }
130 
131  void TestRunOptions() const {
132  if (!IsRunOptionsSupported()) return;
133 
134  SessionBundleConfig config = GetSessionBundleConfig();
135 
136  // Configure the session-config with two threadpools. The first is setup
137  // with default settings. The second is explicitly setup with 1 thread.
138  config.mutable_session_config()->add_session_inter_op_thread_pool();
139  config.mutable_session_config()
140  ->add_session_inter_op_thread_pool()
141  ->set_num_threads(1);
142 
143  // Set the threadpool index to use for session-run calls to 1.
144  config.mutable_session_run_load_threadpool_index()->set_value(1);
145 
146  // Since the session_run_load_threadpool_index in the config is set, the
147  // session-bundle should be loaded successfully from path with RunOptions.
148  std::unique_ptr<Session> session;
149  if (ExpectCreateBundleFailure()) {
150  EXPECT_FALSE(CreateSession(config, &session).ok());
151  return;
152  }
153  TF_ASSERT_OK(CreateSession(config, &session));
154 
155  TestSingleRequest(session.get());
156  }
157 
158  void TestRunOptionsError() const {
159  if (!IsRunOptionsSupported()) return;
160 
161  // Session bundle config with the default global threadpool.
162  SessionBundleConfig config = GetSessionBundleConfig();
163 
164  // Invalid threadpool index to use for session-run calls.
165  config.mutable_session_run_load_threadpool_index()->set_value(100);
166 
167  // Since RunOptions used in the session run calls refers to an invalid
168  // threadpool index, load session bundle from path should fail.
169  std::unique_ptr<Session> session;
170  EXPECT_FALSE(CreateSession(config, &session).ok());
171  }
172 
173  // Test data path, to be initialized to point at a SessionBundle export or
174  // SavedModel of half-plus-two.
175  string export_dir_;
176 
177  private:
178  // Creates a Session with the given configuration and export path.
179  virtual Status CreateSession(const SessionBundleConfig &config,
180  std::unique_ptr<Session> *session) const = 0;
181 
182  // Returns a SessionBundleConfig.
183  virtual SessionBundleConfig GetSessionBundleConfig() const {
184  return SessionBundleConfig();
185  }
186 
187  // Returns true if RunOptions is supported by underlying session.
188  virtual bool IsRunOptionsSupported() const { return true; }
189 
190  // Returns true if CreateBundle is expected to fail.
191  virtual bool ExpectCreateBundleFailure() const { return false; }
192 };
193 
194 } // namespace test_util
195 } // namespace serving
196 } // namespace tensorflow
197 
198 #endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_TEST_H_