16 #include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
23 #include "google/protobuf/wrappers.pb.h"
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/io/path.h"
31 #include "tensorflow/core/platform/test_benchmark.h"
32 #include "tensorflow/core/protobuf/config.pb.h"
33 #include "tensorflow/core/public/session.h"
34 #include "tensorflow/core/public/session_options.h"
35 #include "tensorflow/core/public/version.h"
36 #include "tensorflow_serving/batching/batching_session.h"
37 #include "tensorflow_serving/resources/resources.pb.h"
38 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
39 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
40 #include "tensorflow_serving/session_bundle/session_bundle_util.h"
41 #include "tensorflow_serving/test_util/test_util.h"
42 #include "tensorflow_serving/util/test_util/mock_file_probing_env.h"
44 namespace tensorflow {
// Proto-equality matcher used with EXPECT_THAT in the tests below.
48 using test_util::EqualsProto;
// Scheduler type handed to CreateBatchScheduler / WrapSessionForBatching in
// the batching tests; its tasks are BatchingSessionTask.
50 using Batcher = SharedBatchScheduler<BatchingSessionTask>;
52 class MockSession :
public Session {
54 MOCK_METHOD(tensorflow::Status, Create, (
const GraphDef& graph), (
override));
55 MOCK_METHOD(tensorflow::Status, Extend, (
const GraphDef& graph), (
override));
56 MOCK_METHOD(tensorflow::Status, ListDevices,
57 (std::vector<DeviceAttributes> * response), (
override));
58 MOCK_METHOD(tensorflow::Status, Close, (), (
override));
60 Status Run(
const RunOptions& run_options,
61 const std::vector<std::pair<string, Tensor>>& inputs,
62 const std::vector<string>& output_tensor_names,
63 const std::vector<string>& target_node_names,
64 std::vector<Tensor>* outputs, RunMetadata* run_metadata)
override {
66 const auto& input = inputs[0].second.flat<
float>();
67 Tensor output(DT_FLOAT, inputs[0].second.shape());
68 test::FillFn<float>(&output,
69 [&](
int i) ->
float {
return input(i) / 2 + 2; });
70 outputs->push_back(output);
71 return absl::OkStatus();
75 Status Run(
const std::vector<std::pair<std::string, Tensor>>&,
76 const std::vector<std::string>&,
const std::vector<std::string>&,
77 std::vector<Tensor>* outputs)
override {
78 return errors::Unimplemented(
79 "Run with threadpool is not supported for this session.");
86 class BundleFactoryUtilTest :
public ::testing::Test {
88 BundleFactoryUtilTest() : export_dir_(test_util::GetTestSavedModelPath()) {}
90 virtual ~BundleFactoryUtilTest() =
default;
93 const string export_dir_;
96 TEST_F(BundleFactoryUtilTest, GetSessionOptions) {
97 SessionBundleConfig bundle_config;
99 constexpr
char kTarget[] =
"target";
100 bundle_config.set_session_target(kTarget);
101 ConfigProto *config_proto = bundle_config.mutable_session_config();
102 config_proto->set_allow_soft_placement(
true);
104 SessionOptions session_options = GetSessionOptions(bundle_config);
105 EXPECT_EQ(session_options.target, kTarget);
106 EXPECT_THAT(session_options.config, EqualsProto(*config_proto));
109 TEST_F(BundleFactoryUtilTest, GetRunOptions) {
110 SessionBundleConfig bundle_config;
113 bundle_config.mutable_session_run_load_threadpool_index()->set_value(1);
116 want.set_inter_op_thread_pool(1);
117 EXPECT_THAT(GetRunOptions(bundle_config), EqualsProto(want));
120 TEST_F(BundleFactoryUtilTest, WrapSession) {
121 SavedModelBundle bundle;
122 TF_ASSERT_OK(LoadSavedModel(SessionOptions(), RunOptions(), export_dir_,
123 {
"serve"}, &bundle));
124 TF_ASSERT_OK(WrapSession(&bundle.session));
125 test_util::TestSingleRequest(bundle.session.get());
128 TEST_F(BundleFactoryUtilTest, WrapSessionIgnoreThreadPoolOptions) {
129 std::unique_ptr<Session> session(
new MockSession);
131 TF_ASSERT_OK(WrapSessionIgnoreThreadPoolOptions(&session));
132 test_util::TestSingleRequest(session.get());
// Loads the test SavedModel, builds a batch scheduler from BatchingParameters
// (max_batch_size = 2, max_enqueued_batches = INT_MAX), wraps the session for
// batching, and then drives it with test_util::TestMultipleRequests.
//
// NOTE(review): the WrapSessionForBatching(...) call is not closed in this
// view; its trailing argument(s) appear to be missing. Confirm against the
// full file before editing.
// NOTE(review): the arguments 10 and 2 to TestMultipleRequests are presumably
// the request count and batch size; TODO confirm against the helper.
135 TEST_F(BundleFactoryUtilTest, WrapSessionForBatching) {
136 SavedModelBundle bundle;
137 TF_ASSERT_OK(LoadSavedModel(SessionOptions(), RunOptions(), export_dir_,
138 {
"serve"}, &bundle));
// Batching configuration shared by the scheduler and the wrapper.
141 BatchingParameters batching_params;
142 batching_params.mutable_max_batch_size()->set_value(2);
143 batching_params.mutable_max_enqueued_batches()->set_value(INT_MAX);
// The scheduler is created first and then handed to WrapSessionForBatching.
145 std::shared_ptr<Batcher> batcher;
146 TF_ASSERT_OK(CreateBatchScheduler(batching_params, &batcher));
149 TF_ASSERT_OK(WrapSessionForBatching(batching_params, batcher,
150 {test_util::GetTestSessionSignature()},
154 test_util::TestMultipleRequests(bundle.session.get(), 10, 2);
// Expects WrapSessionForBatching() to return InvalidArgument for a batching
// configuration with max_batch_size = 2 and allowed_batch_sizes = {1, 3}.
// CreateBatchScheduler() is still expected to succeed with these parameters.
//
// NOTE(review): presumably the configuration is rejected because the largest
// allowed batch size (3) does not match max_batch_size (2); TODO confirm
// against the WrapSessionForBatching implementation.
// NOTE(review): the WrapSessionForBatching(...) call is not closed in this
// view; its trailing argument(s) appear to be missing.
157 TEST_F(BundleFactoryUtilTest, WrapSessionForBatchingConfigError) {
158 BatchingParameters batching_params;
159 batching_params.mutable_max_batch_size()->set_value(2);
// Allowed sizes 1 and 3, alongside the max batch size of 2 set above.
162 batching_params.add_allowed_batch_sizes(1);
163 batching_params.add_allowed_batch_sizes(3);
165 std::shared_ptr<Batcher> batch_scheduler;
166 TF_ASSERT_OK(CreateBatchScheduler(batching_params, &batch_scheduler));
168 SavedModelBundle bundle;
169 TF_ASSERT_OK(LoadSavedModel(SessionOptions(), RunOptions(), export_dir_,
170 {
"serve"}, &bundle));
171 auto status = WrapSessionForBatching(batching_params, batch_scheduler,
172 {test_util::GetTestSessionSignature()},
174 ASSERT_TRUE(errors::IsInvalidArgument(status));
// Exercises GetPerModelBatchingParams() against a non-existent directory
// ("does/not/exists") and against testing::TmpDir(), into which a
// batching_params.pbtxt file is written first. Depending on the call, the
// result is expected to equal either the common parameters or the per-model
// parameters read from that file.
//
// NOTE(review): every GetPerModelBatchingParams(...) call below is truncated
// in this view. The argument that distinguishes the four scenarios, and the
// tail of the ASSERT_TRUE expression, are not visible. Confirm against the
// full file before changing any expectation.
177 TEST_F(BundleFactoryUtilTest, GetPerModelBatchingParams) {
// Parameters passed as the "common" configuration to every call.
178 const BatchingParameters common_params =
179 test_util::CreateProto<BatchingParameters>(R
"(
180 allowed_batch_sizes: 8
181 allowed_batch_sizes: 16
182 max_batch_size { value: 16 })");
// Text-format parameters written to disk; they add a batch size of 128 and
// raise max_batch_size to 128 relative to common_params.
184 const string per_model_params_pbtxt(R
"(
185 allowed_batch_sizes: 8
186 allowed_batch_sizes: 16
187 allowed_batch_sizes: 128
188 max_batch_size { value: 128 })");
// Write the per-model parameters to <TmpDir>/batching_params.pbtxt.
190 std::unique_ptr<WritableFile> file;
191 TF_ASSERT_OK(Env::Default()->NewWritableFile(
192 io::JoinPath(testing::TmpDir(), "/batching_params.pbtxt"), &file));
193 TF_ASSERT_OK(file->Append(per_model_params_pbtxt));
194 TF_ASSERT_OK(file->Close());
// Output slot used by the expectations that follow.
196 absl::optional<BatchingParameters> params;
197 TF_ASSERT_OK(GetPerModelBatchingParams(
"does/not/exists", common_params,
200 EXPECT_THAT(params.value(), test_util::EqualsProto(common_params));
203 ASSERT_TRUE(GetPerModelBatchingParams(
"does/not/exists", common_params,
208 TF_ASSERT_OK(GetPerModelBatchingParams(testing::TmpDir(), common_params,
211 EXPECT_THAT(params.value(), test_util::EqualsProto(common_params));
214 TF_ASSERT_OK(GetPerModelBatchingParams(testing::TmpDir(), common_params,
217 EXPECT_THAT(params.value(), test_util::EqualsProto(per_model_params_pbtxt));
220 TEST_F(BundleFactoryUtilTest, EstimateResourceFromPathWithBadExport) {
221 ResourceAllocation resource_requirement;
222 const Status status = EstimateResourceFromPath(
223 "/a/bogus/export/dir",
224 false, &resource_requirement);
225 EXPECT_FALSE(status.ok());
228 TEST_F(BundleFactoryUtilTest, EstimateResourceFromPathWithGoodExport) {
229 const double kTotalFileSize = test_util::GetTotalFileSize(
230 test_util::GetTestSavedModelBundleExportFiles());
231 ResourceAllocation expected =
232 test_util::GetExpectedResourceEstimate(kTotalFileSize);
234 ResourceAllocation actual;
235 TF_ASSERT_OK(EstimateResourceFromPath(
236 export_dir_,
false, &actual));
237 EXPECT_THAT(actual, EqualsProto(expected));
240 #ifdef PLATFORM_GOOGLE
// Benchmark: repeatedly runs the wrapped test SavedModel session, feeding
// "x:0" and fetching "y:0".
//
// The session lives in a function-local static. Only the thread with index 0
// loads the model, wraps the session, and releases it from the bundle into
// that static; nothing in the visible code deletes it.
//
// NOTE(review): the LoadSavedModel(...) call is not closed in this view; its
// trailing argument(s) appear to be missing. Confirm against the full file.
244 void BM_HalfPlusTwo(benchmark::State& state) {
// Shared across all benchmark threads; assigned by thread 0 below.
245 static Session* session;
246 if (state.thread_index() == 0) {
247 SavedModelBundle bundle;
248 TF_ASSERT_OK(LoadSavedModel(SessionOptions(), RunOptions(),
249 test_util::GetTestSavedModelPath(), {
"serve"},
251 TF_ASSERT_OK(WrapSession(&bundle.session));
252 session = bundle.session.release();
// Input fed on every iteration: a 3-element float vector.
254 Tensor input = test::AsTensor<float>({1.0, 2.0, 3.0}, TensorShape({3}));
255 std::vector<Tensor> outputs;
256 for (
auto _ : state) {
258 TF_ASSERT_OK(session->Run({{
"x:0", input}}, {
"y:0"}, {}, &outputs));
// Registered with wall-clock timing, for thread counts from 1 up to 64.
261 BENCHMARK(BM_HalfPlusTwo)->UseRealTime()->ThreadRange(1, 64);