16 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup.h"
20 #include "google/protobuf/wrappers.pb.h"
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/cc/saved_model/constants.h"
24 #include "tensorflow/cc/saved_model/signature_constants.h"
25 #include "tensorflow/core/example/example.pb.h"
26 #include "tensorflow/core/example/feature.pb.h"
27 #include "tensorflow/core/framework/tensor.pb.h"
28 #include "tensorflow/core/framework/tensor_shape.pb.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/io/path.h"
31 #include "tensorflow/core/lib/io/record_writer.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/threadpool_options.h"
35 #include "tensorflow_serving/apis/classification.pb.h"
36 #include "tensorflow_serving/apis/inference.pb.h"
37 #include "tensorflow_serving/apis/input.pb.h"
38 #include "tensorflow_serving/apis/model.pb.h"
39 #include "tensorflow_serving/apis/predict.pb.h"
40 #include "tensorflow_serving/apis/prediction_log.pb.h"
41 #include "tensorflow_serving/apis/regression.pb.h"
42 #include "tensorflow_serving/core/test_util/mock_session.h"
43 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_test_util.h"
44 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
46 namespace tensorflow {
51 using test_util::MockSession;
53 using ::testing::DoAll;
54 using ::testing::Return;
55 using ::testing::SetArgPointee;
56 using ::testing::SizeIs;
// Value-parameterized fixture (TestWithParam<bool>). The bool parameter
// toggles whether the ModelWarmupOptions used by the test set an explicit
// num_request_iterations.
58 class SavedModelBundleWarmupOptionsTest
59 :
public ::testing::TestWithParam<bool> {
// Returns the test parameter: true means "set num_request_iterations".
61 bool EnableNumRequestIterations() {
return GetParam(); }
// Builds the warmup options for this test instance. When the parameter is
// true, num_request_iterations is set to 2; otherwise the options are left
// default-constructed. NOTE(review): the rest of this method (including its
// return) is not visible here; presumably it returns `options`.
63 ModelWarmupOptions GetModelWarmupOptions() {
64 ModelWarmupOptions options;
65 if (EnableNumRequestIterations()) {
66 options.mutable_num_request_iterations()->set_value(2);
// Expected number of times each warmup record is replayed, used to scale
// the mock call counts. NOTE(review): presumably returns 2 when the
// parameter is true (matching GetModelWarmupOptions) and 1 otherwise;
// the body is not fully visible, so confirm.
71 int GetNumRequestIterations() {
72 if (EnableNumRequestIterations()) {
// Writes a warmup file built from AddMixedWarmupData (repeated
// num_warmup_records times, presumably) and checks that RunSavedModelWarmup
// issues the expected number of Session::Run calls for each output-count
// bucket, scaled by GetNumRequestIterations().
79 TEST_P(SavedModelBundleWarmupOptionsTest, MixedWarmupData) {
// Warmup requests live under <tmp>/MixedWarmupData/<assets.extra dir>/.
80 string base_path = io::JoinPath(testing::TmpDir(),
"MixedWarmupData");
81 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
82 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
83 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
84 internal::WarmupConsts::kRequestsFileName);
// Also used below to compute the expected mock call counts.
86 int num_warmup_records = 10;
87 std::vector<string> warmup_records;
88 TF_ASSERT_OK(AddMixedWarmupData(&warmup_records));
89 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, num_warmup_records));
90 SavedModelBundle saved_model_bundle;
91 AddSignatures(&saved_model_bundle.meta_graph_def);
// Ownership of the mock passes to the bundle's session; `mock` stays a
// non-owning handle for setting expectations.
92 MockSession* mock =
new MockSession;
93 saved_model_bundle.session.reset(mock);
// Placeholder output tensors handed back by the mocked Run calls.
94 Tensor scores(DT_FLOAT, TensorShape({1, 1}));
95 Tensor classes(DT_STRING, TensorShape({1, 1}));
// Calls whose 3rd argument has size 1 (presumably the output tensor names
// of the 7-argument Session::Run overload) return a single `scores` tensor
// via the 5th argument. Expected twice per record per iteration; presumably
// two of the mixed request kinds fetch one output. Confirm against
// AddMixedWarmupData.
97 EXPECT_CALL(*mock, Run(_, _, SizeIs(1), _, _, _, _))
98 .Times(num_warmup_records * 2 * GetNumRequestIterations())
99 .WillRepeatedly(DoAll(SetArgPointee<4>(std::vector<Tensor>({scores})),
100 Return(absl::OkStatus())));
// Calls requesting two outputs get {classes, scores}, once per record per
// iteration.
102 EXPECT_CALL(*mock, Run(_, _, SizeIs(2), _, _, _, _))
103 .Times(num_warmup_records * GetNumRequestIterations())
105 DoAll(SetArgPointee<4>(std::vector<Tensor>({classes, scores})),
106 Return(absl::OkStatus())));
// Calls requesting three outputs get {classes, scores, scores}, once per
// record per iteration.
108 EXPECT_CALL(*mock, Run(_, _, SizeIs(3), _, _, _, _))
109 .Times(num_warmup_records * GetNumRequestIterations())
110 .WillRepeatedly(DoAll(
111 SetArgPointee<4>(std::vector<Tensor>({classes, scores, scores})),
112 Return(absl::OkStatus())));
// Warmup must succeed; gMock checks the call counts when the mock is
// destroyed along with the bundle.
113 TF_EXPECT_OK(RunSavedModelWarmup(GetModelWarmupOptions(), RunOptions(),
114 base_path, &saved_model_bundle));
116 INSTANTIATE_TEST_SUITE_P(WarmupOptions, SavedModelBundleWarmupOptionsTest,
119 TEST(SavedModelBundleWarmupTest, UnsupportedLogType_SessionRun) {
120 string base_path = io::JoinPath(testing::TmpDir(),
"SessionRun");
121 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
122 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
123 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
124 internal::WarmupConsts::kRequestsFileName);
126 std::vector<string> warmup_records;
128 TF_ASSERT_OK(AddToWarmupData(&warmup_records, PredictionLog::kSessionRunLog));
129 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, 10));
130 SavedModelBundle saved_model_bundle;
131 AddSignatures(&saved_model_bundle.meta_graph_def);
132 MockSession* mock =
new MockSession;
133 saved_model_bundle.session.reset(mock);
134 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
135 .WillRepeatedly(Return(absl::OkStatus()));
136 const Status status = RunSavedModelWarmup(ModelWarmupOptions(), RunOptions(),
137 base_path, &saved_model_bundle);
138 ASSERT_FALSE(status.ok());
139 EXPECT_EQ(::tensorflow::error::UNIMPLEMENTED, status.code()) << status;
140 EXPECT_THAT(status.ToString(),
141 ::testing::HasSubstr(
"Unsupported log_type for warmup"));
144 TEST(SavedModelBundleWarmupTest, UnsupportedLogType_PredictStreamed) {
145 string base_path = io::JoinPath(testing::TmpDir(),
"PredictStreamed");
146 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
147 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
148 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
149 internal::WarmupConsts::kRequestsFileName);
151 std::vector<string> warmup_records;
154 AddToWarmupData(&warmup_records, PredictionLog::kPredictStreamedLog));
155 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, 10));
156 SavedModelBundle saved_model_bundle;
157 AddSignatures(&saved_model_bundle.meta_graph_def);
158 MockSession* mock =
new MockSession;
159 saved_model_bundle.session.reset(mock);
160 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
161 .WillRepeatedly(Return(absl::OkStatus()));
162 const Status status = RunSavedModelWarmup(ModelWarmupOptions(), RunOptions(),
163 base_path, &saved_model_bundle);
164 ASSERT_FALSE(status.ok());
165 EXPECT_EQ(::tensorflow::error::UNIMPLEMENTED, status.code()) << status;
166 EXPECT_THAT(status.ToString(),
167 ::testing::HasSubstr(
"Unsupported log_type for warmup"));