16 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_util.h"
21 #include "google/protobuf/wrappers.pb.h"
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/cc/saved_model/constants.h"
25 #include "tensorflow/cc/saved_model/signature_constants.h"
26 #include "tensorflow/core/example/example.pb.h"
27 #include "tensorflow/core/example/feature.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/kernels/batching_util/warmup.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/lib/io/path.h"
33 #include "tensorflow/core/lib/io/record_writer.h"
34 #include "tensorflow/core/platform/env.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/platform/threadpool_options.h"
39 #include "tensorflow_serving/apis/classification.pb.h"
40 #include "tensorflow_serving/apis/inference.pb.h"
41 #include "tensorflow_serving/apis/input.pb.h"
42 #include "tensorflow_serving/apis/model.pb.h"
43 #include "tensorflow_serving/apis/predict.pb.h"
44 #include "tensorflow_serving/apis/prediction_log.pb.h"
45 #include "tensorflow_serving/apis/regression.pb.h"
46 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_test_util.h"
47 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
49 namespace tensorflow {
54 constexpr absl::string_view kModelName =
"/ml/owner/model";
55 constexpr int64_t kModelVersion = 0;
56 constexpr int32_t kNumWarmupThreads = 3;
58 class SavedModelBundleWarmupUtilTest :
public ::testing::TestWithParam<bool> {
60 SavedModelBundleWarmupUtilTest() {}
62 bool ParallelWarmUp() {
return GetParam(); }
64 ModelWarmupOptions CreateModelWarmupOptions() {
65 ModelWarmupOptions options;
66 if (ParallelWarmUp()) {
67 options.set_model_name(std::string(kModelName));
68 options.set_model_version(kModelVersion);
69 options.mutable_num_model_warmup_threads()->set_value(kNumWarmupThreads);
74 bool LookupWarmupState()
const {
75 return GetGlobalWarmupStateRegistry().Lookup(
76 {std::string(kModelName), kModelVersion});
79 void FakeRunWarmupRequest() {
80 tensorflow::mutex_lock lock(mu_);
81 is_model_in_warmup_state_registry_ = LookupWarmupState();
82 warmup_request_counter_++;
85 bool is_model_in_warmup_state_registry() {
86 tensorflow::mutex_lock lock(mu_);
87 return is_model_in_warmup_state_registry_;
90 int warmup_request_counter() {
91 tensorflow::mutex_lock lock(mu_);
92 return warmup_request_counter_;
96 tensorflow::mutex mu_;
97 bool is_model_in_warmup_state_registry_ =
false;
98 int warmup_request_counter_ = 0;
101 TEST_P(SavedModelBundleWarmupUtilTest, WarmupStateRegistration) {
102 string base_path = io::JoinPath(testing::TmpDir(),
"WarmupStateRegistration");
103 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
104 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
105 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
106 internal::WarmupConsts::kRequestsFileName);
108 const int num_warmup_records = ParallelWarmUp() ? kNumWarmupThreads : 1;
109 std::vector<string> warmup_records;
111 AddMixedWarmupData(&warmup_records, {PredictionLog::kPredictLog}));
112 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, num_warmup_records));
114 TF_ASSERT_OK(RunSavedModelWarmup(CreateModelWarmupOptions(), base_path,
115 [
this](PredictionLog prediction_log) {
116 this->FakeRunWarmupRequest();
117 return absl::OkStatus();
119 EXPECT_EQ(warmup_request_counter(), num_warmup_records);
120 EXPECT_EQ(is_model_in_warmup_state_registry(), ParallelWarmUp());
123 EXPECT_FALSE(LookupWarmupState());
126 TEST_P(SavedModelBundleWarmupUtilTest, NoWarmupDataFile) {
127 string base_path = io::JoinPath(testing::TmpDir(),
"NoWarmupDataFile");
128 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
129 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
131 SavedModelBundle saved_model_bundle;
132 AddSignatures(&saved_model_bundle.meta_graph_def);
133 TF_EXPECT_OK(RunSavedModelWarmup(CreateModelWarmupOptions(), base_path,
134 [
this](PredictionLog prediction_log) {
135 this->FakeRunWarmupRequest();
136 return absl::OkStatus();
138 EXPECT_EQ(warmup_request_counter(), 0);
141 TEST_P(SavedModelBundleWarmupUtilTest, WarmupDataFileEmpty) {
142 string base_path = io::JoinPath(testing::TmpDir(),
"WarmupDataFileEmpty");
143 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
144 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
145 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
146 internal::WarmupConsts::kRequestsFileName);
148 std::vector<string> warmup_records;
149 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, 0));
150 SavedModelBundle saved_model_bundle;
151 AddSignatures(&saved_model_bundle.meta_graph_def);
152 TF_EXPECT_OK(RunSavedModelWarmup(CreateModelWarmupOptions(), base_path,
153 [
this](PredictionLog prediction_log) {
154 this->FakeRunWarmupRequest();
155 return absl::OkStatus();
157 EXPECT_EQ(warmup_request_counter(), 0);
160 TEST_P(SavedModelBundleWarmupUtilTest, UnsupportedFileFormat) {
161 string base_path = io::JoinPath(testing::TmpDir(),
"UnsupportedFileFormat");
162 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
163 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
164 const string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
165 internal::WarmupConsts::kRequestsFileName);
167 std::vector<string> warmup_records;
169 PredictionLog prediction_log;
171 PopulatePredictionLog(&prediction_log, PredictionLog::kSessionRunLog));
172 warmup_records.push_back(prediction_log.SerializeAsString());
174 TF_ASSERT_OK(WriteWarmupDataAsSerializedProtos(fname, warmup_records, 10));
175 SavedModelBundle saved_model_bundle;
176 AddSignatures(&saved_model_bundle.meta_graph_def);
177 const Status status = RunSavedModelWarmup(
178 CreateModelWarmupOptions(), base_path,
179 [](PredictionLog prediction_log) {
return absl::OkStatus(); });
180 ASSERT_FALSE(status.ok());
181 EXPECT_EQ(::tensorflow::error::DATA_LOSS, status.code()) << status;
182 EXPECT_THAT(status.ToString(),
183 ::testing::HasSubstr(
184 "Please verify your warmup data is in TFRecord format"));
187 TEST_P(SavedModelBundleWarmupUtilTest, TooManyWarmupRecords) {
188 string base_path = io::JoinPath(testing::TmpDir(),
"TooManyWarmupRecords");
189 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
190 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
191 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
192 internal::WarmupConsts::kRequestsFileName);
194 std::vector<string> warmup_records;
195 TF_ASSERT_OK(AddMixedWarmupData(&warmup_records));
196 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records,
197 internal::WarmupConsts::kMaxNumRecords + 1));
198 SavedModelBundle saved_model_bundle;
199 AddSignatures(&saved_model_bundle.meta_graph_def);
200 const Status status = RunSavedModelWarmup(
201 CreateModelWarmupOptions(), base_path,
202 [](PredictionLog prediction_log) {
return absl::OkStatus(); });
203 ASSERT_FALSE(status.ok());
204 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
209 ::testing::HasSubstr(
"Number of warmup records exceeds the maximum"));
212 TEST_P(SavedModelBundleWarmupUtilTest, UnparsableRecord) {
213 string base_path = io::JoinPath(testing::TmpDir(),
"UnparsableRecord");
214 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
215 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
216 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
217 internal::WarmupConsts::kRequestsFileName);
219 std::vector<string> warmup_records = {
"malformed_record"};
220 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, 10));
221 SavedModelBundle saved_model_bundle;
222 const Status status = RunSavedModelWarmup(
223 CreateModelWarmupOptions(), base_path,
224 [](PredictionLog prediction_log) {
return absl::OkStatus(); });
225 ASSERT_FALSE(status.ok());
226 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
229 EXPECT_THAT(status.ToString(),
230 ::testing::HasSubstr(
"Failed to parse warmup record"));
233 TEST_P(SavedModelBundleWarmupUtilTest, RunFailure) {
234 string base_path = io::JoinPath(testing::TmpDir(),
"RunFailure");
235 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
236 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
237 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
238 internal::WarmupConsts::kRequestsFileName);
240 int num_warmup_records = 10;
241 std::vector<string> warmup_records;
242 TF_ASSERT_OK(AddMixedWarmupData(&warmup_records));
243 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, num_warmup_records));
244 SavedModelBundle saved_model_bundle;
245 AddSignatures(&saved_model_bundle.meta_graph_def);
246 Status status = RunSavedModelWarmup(
247 CreateModelWarmupOptions(), base_path, [](PredictionLog prediction_log) {
248 return errors::InvalidArgument(
"Run failed");
250 ASSERT_FALSE(status.ok());
251 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
254 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"Run failed"));
256 INSTANTIATE_TEST_SUITE_P(ParallelWarmUp, SavedModelBundleWarmupUtilTest,