16 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_warmup.h"
20 #include "google/protobuf/wrappers.pb.h"
21 #include "tensorflow/cc/saved_model/constants.h"
22 #include "tensorflow/cc/saved_model/signature_constants.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
26 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
27 #include "tensorflow/core/tfrt/utils/tensor_util.h"
28 #include "tsl/platform/path.h"
29 #include "tensorflow_serving/apis/classification.pb.h"
30 #include "tensorflow_serving/apis/inference.pb.h"
31 #include "tensorflow_serving/apis/input.pb.h"
32 #include "tensorflow_serving/apis/model.pb.h"
33 #include "tensorflow_serving/apis/predict.pb.h"
34 #include "tensorflow_serving/apis/prediction_log.pb.h"
35 #include "tensorflow_serving/apis/regression.pb.h"
36 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_test_util.h"
37 #include "tensorflow_serving/servables/tensorflow/test_util/mock_tfrt_saved_model.h"
39 namespace tensorflow {
44 using ::testing::DoAll;
45 using ::testing::Return;
46 using ::testing::ReturnRef;
47 using ::testing::WithArgs;
// Value-parameterized fixture: the bool test parameter (GetParam()) controls
// whether the ModelWarmupOptions handed to RunSavedModelWarmup set
// num_request_iterations.
25 class TFRTSavedModelWarmupOptionsTest :
public ::testing::TestWithParam<bool> {
// Returns the test parameter; true means the warmup options below request a
// non-default number of iterations per warmup record.
51 bool EnableNumRequestIterations() {
return GetParam(); }
// Builds the ModelWarmupOptions used by the TEST_Ps; when the parameter is
// true, num_request_iterations is set to 2.
53 ModelWarmupOptions GetModelWarmupOptions() {
54 ModelWarmupOptions options;
55 if (EnableNumRequestIterations()) {
56 options.mutable_num_request_iterations()->set_value(2);
// Expected replay count per warmup record; the tests multiply it by the record
// count in their mock .Times() expectations.
// NOTE(review): presumably returns 2 when enabled and 1 otherwise, matching
// GetModelWarmupOptions(); confirm against the full body.
61 int GetNumRequestIterations() {
62 if (EnableNumRequestIterations()) {
// Replays a mixed warmup file (built by AddMixedWarmupData) against a mocked
// TFRT SavedModel. Judging by the expectations below, the file holds predict,
// regress, classify and multi-signature requests. The test checks that each
// kind is executed once per record per request iteration.
69 TEST_P(TFRTSavedModelWarmupOptionsTest, MixedWarmupData) {
70 string base_path = io::JoinPath(testing::TmpDir(),
"MixedWarmupData");
71 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
72 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
73 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
74 internal::WarmupConsts::kRequestsFileName);
// num_warmup_records is passed to WriteWarmupData and also scales every
// .Times() below. NOTE(review): presumably the file contains this many copies
// of each record; confirm against the test util.
76 int num_warmup_records = 10;
77 std::vector<string> warmup_records;
78 TF_ASSERT_OK(AddMixedWarmupData(&warmup_records));
79 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, num_warmup_records));
81 std::unique_ptr<test_util::MockSavedModel> saved_model(
82 (
new test_util::MockSavedModel()));
// Per-method function metadata returned by the mock's GetFunctionMetadata().
// Only the predict signature declares input specs (a single DT_STRING tensor).
83 tfrt::internal::Signature predict_signature;
84 predict_signature.input_names = {kPredictInputs};
85 tfrt::TensorSpec spec(tensorflow::DT_STRING);
86 predict_signature.input_specs = {spec};
87 predict_signature.output_names = {kPredictOutputs};
88 tfrt::FunctionMetadata predict_function_metadata(&predict_signature);
89 EXPECT_CALL(*saved_model, GetFunctionMetadata(kPredictMethodName))
90 .WillRepeatedly(Return(predict_function_metadata));
92 tfrt::internal::Signature classify_signature;
93 classify_signature.input_names = {kClassifyInputs};
94 classify_signature.output_names = {kClassifyOutputClasses,
95 kClassifyOutputScores};
96 tfrt::FunctionMetadata classify_function_metadata(&classify_signature);
97 EXPECT_CALL(*saved_model, GetFunctionMetadata(kClassifyMethodName))
98 .WillRepeatedly(Return(classify_function_metadata));
100 tfrt::internal::Signature regress_signature;
101 regress_signature.input_names = {kRegressInputs};
102 regress_signature.output_names = {kRegressOutputs};
// Like the predict/classify metadata above, FunctionMetadata is built from a
// pointer to the local signature object.
103 tfrt::FunctionMetadata regress_function_metadata(&regress_signature);
104 EXPECT_CALL(*saved_model, GetFunctionMetadata(kRegressMethodName))
105 .WillRepeatedly(Return(regress_function_metadata));
// Signature defs served by the mock's GetMetaGraphDef(), filled in by
// AddSignatures.
107 MetaGraphDef meta_graph_def;
108 AddSignatures(&meta_graph_def);
109 EXPECT_CALL(*saved_model, GetMetaGraphDef())
110 .WillRepeatedly(ReturnRef(meta_graph_def));
// Placeholder output tensors that the mocked Run calls hand back.
112 Tensor scores(DT_FLOAT, TensorShape({1, 1}));
113 Tensor classes(DT_STRING, TensorShape({1, 1}));
// Single-signature runs: outputs are appended through argument index 3 (the
// output tensor vector).
115 EXPECT_CALL(*saved_model, Run(_, ::testing::Eq(kPredictMethodName),
116 ::testing::An<absl::Span<const Tensor>>(), _))
117 .Times(num_warmup_records * GetNumRequestIterations())
119 DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
120 output_tensors->push_back(scores);
122 Return(absl::OkStatus())));
124 EXPECT_CALL(*saved_model, Run(_, ::testing::Eq(kRegressMethodName),
125 ::testing::An<absl::Span<const Tensor>>(), _))
126 .Times(num_warmup_records * GetNumRequestIterations())
128 DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
129 output_tensors->push_back(scores);
131 Return(absl::OkStatus())));
133 EXPECT_CALL(*saved_model, Run(_, ::testing::Eq(kClassifyMethodName),
134 ::testing::An<absl::Span<const Tensor>>(), _))
135 .Times(num_warmup_records * GetNumRequestIterations())
137 DoAll(WithArgs<3>([&](std::vector<Tensor>* output_tensors) {
138 output_tensors->push_back(classes);
139 output_tensors->push_back(scores);
141 Return(absl::OkStatus())));
// Multi-signature run: slot 0 receives one tensor (scores) and slot 1 receives
// two (classes, scores). This mirrors the regress and classify mocks above.
143 EXPECT_CALL(*saved_model, RunMultipleSignatures(_, _, _, _))
144 .Times(num_warmup_records * GetNumRequestIterations())
145 .WillRepeatedly(DoAll(
146 WithArgs<3>([&](std::vector<std::vector<Tensor>>* output_tensors) {
147 output_tensors->resize(2);
148 (*output_tensors)[0].push_back(scores);
149 (*output_tensors)[1].push_back(classes);
150 (*output_tensors)[1].push_back(scores);
152 Return(absl::OkStatus())));
154 TF_EXPECT_OK(RunSavedModelWarmup(GetModelWarmupOptions(), base_path,
// Verifies warmup of PredictStreamed log records. Every replayed request must
// reach SavedModel::Run for the predict method with run options whose
// streamed_output_callback is non-null.
160 TEST_P(TFRTSavedModelWarmupOptionsTest, PredictStreamedWarmupData) {
161 std::string base_path =
162 tsl::io::JoinPath(testing::TmpDir(),
"PredictStreamedWarmupData");
163 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
164 tsl::io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
166 tsl::io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
167 internal::WarmupConsts::kRequestsFileName);
169 int num_warmup_records = 10;
170 std::vector<std::string> warmup_records;
172 AddToWarmupData(&warmup_records, PredictionLog::kPredictStreamedLog));
173 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, num_warmup_records));
175 auto saved_model = std::make_unique<test_util::MockSavedModel>();
// Only the predict method is described: one DT_STRING input and no output
// names.
177 tfrt::internal::Signature signature;
178 signature.input_names = {kPredictInputs};
179 signature.input_specs = {tfrt::TensorSpec(tensorflow::DT_STRING)};
180 tfrt::FunctionMetadata function_metadata(&signature);
181 EXPECT_CALL(*saved_model, GetFunctionMetadata(kPredictMethodName))
182 .WillRepeatedly(Return(function_metadata));
184 MetaGraphDef meta_graph_def;
185 (*meta_graph_def.mutable_signature_def())[kPredictMethodName] =
186 CreateSignatureDef(kPredictMethodName, {kPredictInputs}, {});
188 EXPECT_CALL(*saved_model, GetMetaGraphDef())
189 .WillRepeatedly(ReturnRef(meta_graph_def));
// Each predict Run must carry a streamed-output callback. The mock returns OK
// without producing any output tensors.
193 Run(::testing::Field(
194 &tfrt_stub::GraphExecutionRunOptions::streamed_output_callback,
195 ::testing::NotNull()),
196 ::testing::Eq(kPredictMethodName),
197 ::testing::An<absl::Span<const Tensor>>(), _))
198 .Times(num_warmup_records * GetNumRequestIterations())
199 .WillRepeatedly(Return(absl::OkStatus()));
201 TF_EXPECT_OK(RunSavedModelWarmup(GetModelWarmupOptions(), base_path,
// Runs every TEST_P of TFRTSavedModelWarmupOptionsTest under the WarmupOptions
// prefix, once for each bool parameter value supplied to the generator.
207 INSTANTIATE_TEST_SUITE_P(WarmupOptions, TFRTSavedModelWarmupOptionsTest,
// Verifies that a warmup file containing a SessionRun log
// (PredictionLog::kSessionRunLog) makes RunSavedModelWarmup fail with
// UNIMPLEMENTED and an "Unsupported log_type for warmup" message.
210 TEST(TFRTSavedModelWarmupTest, UnsupportedLogType) {
211 string base_path = io::JoinPath(testing::TmpDir(),
"UnsupportedLogType");
212 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
213 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
214 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
215 internal::WarmupConsts::kRequestsFileName);
217 std::vector<string> warmup_records;
219 PredictionLog prediction_log;
221 PopulatePredictionLog(&prediction_log, PredictionLog::kSessionRunLog));
222 warmup_records.push_back(prediction_log.SerializeAsString());
223 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, 10));
225 std::unique_ptr<test_util::MockSavedModel> saved_model(
226 (
new test_util::MockSavedModel()));
227 MetaGraphDef meta_graph_def;
228 AddSignatures(&meta_graph_def);
229 EXPECT_CALL(*saved_model, GetMetaGraphDef())
230 .WillRepeatedly(ReturnRef(meta_graph_def));
// NOTE(review): the meaning of the positional `true` argument is not evident
// at this call site; an /*arg_name=*/ comment would make it readable.
231 const Status status = RunSavedModelWarmup(
232 ModelWarmupOptions(), base_path,
234 true, saved_model.get());
235 ASSERT_FALSE(status.ok());
236 EXPECT_EQ(::tensorflow::error::UNIMPLEMENTED, status.code()) << status;
237 EXPECT_THAT(status.ToString(),
238 ::testing::HasSubstr(
"Unsupported log_type for warmup"));
// Writes regress, classify and predict warmup records, then expects that no
// function metadata is ever looked up (GetFunctionMetadata().Times(0) for every
// method). In other words, the warmup requests are skipped entirely.
// NOTE(review): the skip is presumably driven by the arguments passed to
// RunSavedModelWarmup below; confirm which flag triggers it.
241 TEST(TFRTSavedModelWarmupTest, SkipWarmupRequest) {
242 string base_path = io::JoinPath(testing::TmpDir(),
"SkipWarmupRequest");
243 TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
244 io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
245 string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
246 internal::WarmupConsts::kRequestsFileName);
248 int num_warmup_records = 10;
249 std::vector<string> warmup_records;
250 TF_ASSERT_OK(AddMixedWarmupData(
251 &warmup_records, {PredictionLog::kRegressLog, PredictionLog::kClassifyLog,
252 PredictionLog::kPredictLog}));
253 TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, num_warmup_records));
255 std::unique_ptr<test_util::MockSavedModel> saved_model(
256 (
new test_util::MockSavedModel()));
// No request may be dispatched, so no method's metadata should be fetched.
257 EXPECT_CALL(*saved_model, GetFunctionMetadata(kPredictMethodName)).Times(0);
258 EXPECT_CALL(*saved_model, GetFunctionMetadata(kClassifyMethodName)).Times(0);
259 EXPECT_CALL(*saved_model, GetFunctionMetadata(kRegressMethodName)).Times(0);
260 MetaGraphDef meta_graph_def;
261 AddSignatures(&meta_graph_def);
262 EXPECT_CALL(*saved_model, GetMetaGraphDef())
263 .WillRepeatedly(ReturnRef(meta_graph_def));
265 TF_EXPECT_OK(RunSavedModelWarmup(ModelWarmupOptions(), base_path,