16 #ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_TEST_H_
17 #define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_BUNDLE_FACTORY_TEST_H_
24 #include "google/protobuf/wrappers.pb.h"
25 #include <gmock/gmock.h>
26 #include <gtest/gtest.h>
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/monitoring/sampler.h"
30 #include "tensorflow/core/public/session.h"
31 #include "tensorflow_serving/resources/resources.pb.h"
32 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
33 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
34 #include "tensorflow_serving/test_util/test_util.h"
36 namespace tensorflow {
40 using test_util::EqualsProto;
46 : export_dir_(export_dir) {}
// Creates a session from GetSessionBundleConfig(). If
// ExpectCreateBundleFailure() is true, only checks that CreateSession()
// fails; otherwise asserts that it succeeds and runs TestSingleRequest()
// (defined outside this excerpt) against the new session.
52 void TestBasic()
const {
53 const SessionBundleConfig config = GetSessionBundleConfig();
54 std::unique_ptr<Session> session;
55 if (ExpectCreateBundleFailure()) {
56 EXPECT_FALSE(CreateSession(config, &session).ok());
// NOTE(review): listing lines 57-58 are elided from this excerpt;
// presumably they return early and close this branch — confirm against the
// full file.
59 TF_ASSERT_OK(CreateSession(config, &session));
60 TestSingleRequest(session.get());
// Returns the sum of the int64_value of every point recorded under the
// metric "/tensorflow/serving/batching_session/wrapped_run_count" in the
// default monitoring::CollectionRegistry, or 0 when no entry for that
// metric is present. TestBatching() reads it before and after sending
// requests and compares the two values.
63 int GetTotalBatchesProcessed()
const {
// NOTE(review): the declaration that opens this string literal (listing
// line 64, presumably of `label`, used below) is elided from this excerpt.
65 "/tensorflow/serving/batching_session/wrapped_run_count");
66 auto* collection_registry = monitoring::CollectionRegistry::Default();
67 monitoring::CollectionRegistry::CollectMetricsOptions options;
68 const std::unique_ptr<monitoring::CollectedMetrics> collected_metrics =
69 collection_registry->CollectMetrics(options);
// NOTE(review): listing line 70 is elided; `total_count` is accumulated
// below, but its declaration (presumably a zero-initialized 64-bit integer)
// is not visible here — confirm.
71 const auto& point_set_map = collected_metrics->point_set_map;
// A missing metric entry is reported as zero batches processed.
72 if (point_set_map.find(label) == point_set_map.end())
return 0;
// NOTE(review): find() followed by at() looks the key up twice; reusing
// the iterator from a single find() would avoid the second lookup.
73 const monitoring::PointSet& lps = *point_set_map.at(label);
// Sum the value of every point in the set. NOTE(review): `int i` is
// compared against the unsigned size(); a range-for over lps.points would
// avoid the sign-compare warning.
74 for (
int i = 0; i < lps.points.size(); ++i) {
75 total_count += lps.points[i]->int64_value;
// NOTE(review): listing line 76 (presumably the loop's closing brace) is
// elided. If total_count is wider than int (its type is not visible here),
// the cast below truncates very large totals.
77 return static_cast<int>(total_count);
// Runs a batching test:
//  1. Copies `params` into the config's batching_parameters and sets
//     enable_per_model_batching_params.
//  2. Sends 10 requests of `input_request_batch_size` each through
//     TestMultipleRequests().
//  3. Expects the counter from GetTotalBatchesProcessed() to grow by
//     (input_request_batch_size * 10) / batch_size.
// If ExpectCreateBundleFailure() is true, only checks that CreateSession()
// fails.
// `batch_size` is used only to compute the expected batch count; presumably
// it matches the batch size configured via `params` — confirm at callers.
80 void TestBatching(
const BatchingParameters& params,
81 bool enable_per_model_batching_params,
82 int input_request_batch_size,
int batch_size)
const {
83 SessionBundleConfig config = GetSessionBundleConfig();
84 config.set_enable_per_model_batching_params(
85 enable_per_model_batching_params);
86 BatchingParameters* batching_params = config.mutable_batching_parameters();
87 *batching_params = params;
// Override two of the caller's settings with INT_MAX. NOTE(review): the
// explanatory lines (listing 88-93 and 95-99) are elided. Presumably this
// keeps the queue from rejecting work and keeps batches from being flushed
// early by the timeout, so the batch count below is deterministic — confirm.
94 batching_params->mutable_max_enqueued_batches()->set_value(INT_MAX);
100 batching_params->mutable_batch_timeout_micros()->set_value(INT_MAX);
102 std::unique_ptr<Session> session;
103 if (ExpectCreateBundleFailure()) {
104 EXPECT_FALSE(CreateSession(config, &session).ok());
// NOTE(review): listing lines 105-106 are elided; presumably they return
// early and close this branch — confirm.
107 TF_ASSERT_OK(CreateSession(config, &session));
109 const int num_requests = 10;
// NOTE(review): integer division. If input_request_batch_size * 10 is not
// a multiple of batch_size, the expectation is truncated; presumably
// callers choose values that divide evenly.
110 const int expected_batches =
111 (input_request_batch_size * num_requests) / batch_size;
// Snapshot the global counter first; other tests may already have added
// to it, so only the delta is checked.
112 const int orig_batches_processed = GetTotalBatchesProcessed();
113 TestMultipleRequests(session.get(), num_requests, input_request_batch_size);
114 EXPECT_EQ(orig_batches_processed + expected_batches,
115 GetTotalBatchesProcessed());
// Creates a FactoryType from GetSessionBundleConfig() and asserts that its
// resource estimate for export_dir_ equals
// GetExpectedResourceEstimate(total_file_size), compared with EqualsProto.
// FactoryType must provide:
//  - a static Create(const SessionBundleConfig&,
//    std::unique_ptr<FactoryType>*) returning a Status;
//  - an EstimateResourceRequirement(export_dir, ResourceAllocation*) member.
// Unlike the session tests, this one does not consult
// ExpectCreateBundleFailure().
118 template <
class FactoryType>
119 void TestEstimateResourceRequirementWithGoodExport(
120 double total_file_size)
const {
121 const SessionBundleConfig config = GetSessionBundleConfig();
122 std::unique_ptr<FactoryType> factory;
123 TF_ASSERT_OK(FactoryType::Create(config, &factory));
124 ResourceAllocation actual;
125 TF_ASSERT_OK(factory->EstimateResourceRequirement(export_dir_, &actual));
127 ResourceAllocation expected = GetExpectedResourceEstimate(total_file_size);
128 EXPECT_THAT(actual, EqualsProto(expected));
// Configures two inter-op thread pools, sets
// session_run_load_threadpool_index to 1, and then either expects
// CreateSession() to fail (if ExpectCreateBundleFailure() is true) or
// creates the session and runs TestSingleRequest() on it.
// Does nothing when IsRunOptionsSupported() returns false.
131 void TestRunOptions()
const {
132 if (!IsRunOptionsSupported())
return;
134 SessionBundleConfig config = GetSessionBundleConfig();
// Add two inter-op pools: the first with no fields set here, the second
// limited to one thread. NOTE(review): listing lines 135-137 (likely
// explanatory comments) are elided.
138 config.mutable_session_config()->add_session_inter_op_thread_pool();
139 config.mutable_session_config()
140 ->add_session_inter_op_thread_pool()
141 ->set_num_threads(1);
// Index 1 presumably selects the second (single-thread) pool added above
// — confirm against the elided listing lines 142-143.
144 config.mutable_session_run_load_threadpool_index()->set_value(1);
148 std::unique_ptr<Session> session;
149 if (ExpectCreateBundleFailure()) {
150 EXPECT_FALSE(CreateSession(config, &session).ok());
// NOTE(review): listing lines 151-152 are elided; presumably they return
// early and close this branch — confirm.
153 TF_ASSERT_OK(CreateSession(config, &session));
155 TestSingleRequest(session.get());
// Expects CreateSession() to fail when session_run_load_threadpool_index is
// set to 100. Presumably that value is out of range for the thread pools in
// the config; no pools are added in the visible lines — confirm.
// Does nothing when IsRunOptionsSupported() returns false.
158 void TestRunOptionsError()
const {
159 if (!IsRunOptionsSupported())
return;
162 SessionBundleConfig config = GetSessionBundleConfig();
// NOTE(review): listing lines 163-164 and 166-168 are elided; they may
// hold comments or further config changes.
165 config.mutable_session_run_load_threadpool_index()->set_value(100);
169 std::unique_ptr<Session> session;
170 EXPECT_FALSE(CreateSession(config, &session).ok());
// Hook each concrete test fixture must implement: builds a Session for
// `config` and stores it in `*session`. The tests check the returned Status
// with TF_ASSERT_OK or EXPECT_FALSE(...ok()).
179 virtual Status CreateSession(
const SessionBundleConfig &config,
180 std::unique_ptr<Session> *session)
const = 0;
// Base config that the tests start from. The default implementation returns
// a default-constructed SessionBundleConfig; it is virtual so subclasses can
// override it.
183 virtual SessionBundleConfig GetSessionBundleConfig()
const {
184 return SessionBundleConfig();
// NOTE(review): the closing brace (listing line 185 or later) is elided
// from this excerpt.
// When this returns false, TestRunOptions() and TestRunOptionsError() return
// immediately without checking anything. Defaults to true.
188 virtual bool IsRunOptionsSupported()
const {
return true; }
// When true, TestBasic(), TestBatching() and TestRunOptions() only check
// that CreateSession() fails instead of exercising the session.
// Defaults to false.
191 virtual bool ExpectCreateBundleFailure()
const {
return false; }