TensorFlow Serving C++ API Documentation
saved_model_warmup_util_test.cc
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_util.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "google/protobuf/wrappers.pb.h"
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/cc/saved_model/constants.h"
25 #include "tensorflow/cc/saved_model/signature_constants.h"
26 #include "tensorflow/core/example/example.pb.h"
27 #include "tensorflow/core/example/feature.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/kernels/batching_util/warmup.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/lib/io/path.h"
33 #include "tensorflow/core/lib/io/record_writer.h"
34 #include "tensorflow/core/platform/env.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/platform/threadpool_options.h"
39 #include "tensorflow_serving/apis/classification.pb.h"
40 #include "tensorflow_serving/apis/inference.pb.h"
41 #include "tensorflow_serving/apis/input.pb.h"
42 #include "tensorflow_serving/apis/model.pb.h"
43 #include "tensorflow_serving/apis/predict.pb.h"
44 #include "tensorflow_serving/apis/prediction_log.pb.h"
45 #include "tensorflow_serving/apis/regression.pb.h"
46 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_test_util.h"
47 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
48 
49 namespace tensorflow {
50 namespace serving {
51 namespace internal {
52 namespace {
53 
54 constexpr absl::string_view kModelName = "/ml/owner/model";
55 constexpr int64_t kModelVersion = 0;
56 constexpr int32_t kNumWarmupThreads = 3;
57 
58 class SavedModelBundleWarmupUtilTest : public ::testing::TestWithParam<bool> {
59  protected:
60  SavedModelBundleWarmupUtilTest() {}
61 
62  bool ParallelWarmUp() { return GetParam(); }
63 
64  ModelWarmupOptions CreateModelWarmupOptions() {
65  ModelWarmupOptions options;
66  if (ParallelWarmUp()) {
67  options.set_model_name(std::string(kModelName));
68  options.set_model_version(kModelVersion);
69  options.mutable_num_model_warmup_threads()->set_value(kNumWarmupThreads);
70  }
71  return options;
72  }
73 
74  bool LookupWarmupState() const {
75  return GetGlobalWarmupStateRegistry().Lookup(
76  {std::string(kModelName), kModelVersion});
77  }
78 
79  void FakeRunWarmupRequest() {
80  tensorflow::mutex_lock lock(mu_);
81  is_model_in_warmup_state_registry_ = LookupWarmupState();
82  warmup_request_counter_++;
83  }
84 
85  bool is_model_in_warmup_state_registry() {
86  tensorflow::mutex_lock lock(mu_);
87  return is_model_in_warmup_state_registry_;
88  }
89 
90  int warmup_request_counter() {
91  tensorflow::mutex_lock lock(mu_);
92  return warmup_request_counter_;
93  }
94 
95  private:
96  tensorflow::mutex mu_;
97  bool is_model_in_warmup_state_registry_ = false;
98  int warmup_request_counter_ = 0;
99 };
100 
101 TEST_P(SavedModelBundleWarmupUtilTest, WarmupStateRegistration) {
102  string base_path = io::JoinPath(testing::TmpDir(), "WarmupStateRegistration");
103  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
104  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
105  string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
106  internal::WarmupConsts::kRequestsFileName);
107 
108  const int num_warmup_records = ParallelWarmUp() ? kNumWarmupThreads : 1;
109  std::vector<string> warmup_records;
110  TF_ASSERT_OK(
111  AddMixedWarmupData(&warmup_records, {PredictionLog::kPredictLog}));
112  TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, num_warmup_records));
113 
114  TF_ASSERT_OK(RunSavedModelWarmup(CreateModelWarmupOptions(), base_path,
115  [this](PredictionLog prediction_log) {
116  this->FakeRunWarmupRequest();
117  return absl::OkStatus();
118  }));
119  EXPECT_EQ(warmup_request_counter(), num_warmup_records);
120  EXPECT_EQ(is_model_in_warmup_state_registry(), ParallelWarmUp());
121  // The model should be unregistered from the WarmupStateRegistry after
122  // warm-up.
123  EXPECT_FALSE(LookupWarmupState());
124 }
125 
126 TEST_P(SavedModelBundleWarmupUtilTest, NoWarmupDataFile) {
127  string base_path = io::JoinPath(testing::TmpDir(), "NoWarmupDataFile");
128  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
129  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
130 
131  SavedModelBundle saved_model_bundle;
132  AddSignatures(&saved_model_bundle.meta_graph_def);
133  TF_EXPECT_OK(RunSavedModelWarmup(CreateModelWarmupOptions(), base_path,
134  [this](PredictionLog prediction_log) {
135  this->FakeRunWarmupRequest();
136  return absl::OkStatus();
137  }));
138  EXPECT_EQ(warmup_request_counter(), 0);
139 }
140 
141 TEST_P(SavedModelBundleWarmupUtilTest, WarmupDataFileEmpty) {
142  string base_path = io::JoinPath(testing::TmpDir(), "WarmupDataFileEmpty");
143  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
144  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
145  string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
146  internal::WarmupConsts::kRequestsFileName);
147 
148  std::vector<string> warmup_records;
149  TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, 0));
150  SavedModelBundle saved_model_bundle;
151  AddSignatures(&saved_model_bundle.meta_graph_def);
152  TF_EXPECT_OK(RunSavedModelWarmup(CreateModelWarmupOptions(), base_path,
153  [this](PredictionLog prediction_log) {
154  this->FakeRunWarmupRequest();
155  return absl::OkStatus();
156  }));
157  EXPECT_EQ(warmup_request_counter(), 0);
158 }
159 
160 TEST_P(SavedModelBundleWarmupUtilTest, UnsupportedFileFormat) {
161  string base_path = io::JoinPath(testing::TmpDir(), "UnsupportedFileFormat");
162  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
163  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
164  const string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
165  internal::WarmupConsts::kRequestsFileName);
166 
167  std::vector<string> warmup_records;
168  // Add unsupported log type
169  PredictionLog prediction_log;
170  TF_ASSERT_OK(
171  PopulatePredictionLog(&prediction_log, PredictionLog::kSessionRunLog));
172  warmup_records.push_back(prediction_log.SerializeAsString());
173 
174  TF_ASSERT_OK(WriteWarmupDataAsSerializedProtos(fname, warmup_records, 10));
175  SavedModelBundle saved_model_bundle;
176  AddSignatures(&saved_model_bundle.meta_graph_def);
177  const Status status = RunSavedModelWarmup(
178  CreateModelWarmupOptions(), base_path,
179  [](PredictionLog prediction_log) { return absl::OkStatus(); });
180  ASSERT_FALSE(status.ok());
181  EXPECT_EQ(::tensorflow::error::DATA_LOSS, status.code()) << status;
182  EXPECT_THAT(status.ToString(),
183  ::testing::HasSubstr(
184  "Please verify your warmup data is in TFRecord format"));
185 }
186 
187 TEST_P(SavedModelBundleWarmupUtilTest, TooManyWarmupRecords) {
188  string base_path = io::JoinPath(testing::TmpDir(), "TooManyWarmupRecords");
189  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
190  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
191  string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
192  internal::WarmupConsts::kRequestsFileName);
193 
194  std::vector<string> warmup_records;
195  TF_ASSERT_OK(AddMixedWarmupData(&warmup_records));
196  TF_ASSERT_OK(WriteWarmupData(fname, warmup_records,
197  internal::WarmupConsts::kMaxNumRecords + 1));
198  SavedModelBundle saved_model_bundle;
199  AddSignatures(&saved_model_bundle.meta_graph_def);
200  const Status status = RunSavedModelWarmup(
201  CreateModelWarmupOptions(), base_path,
202  [](PredictionLog prediction_log) { return absl::OkStatus(); });
203  ASSERT_FALSE(status.ok());
204  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
205  status.code())
206  << status;
207  EXPECT_THAT(
208  status.ToString(),
209  ::testing::HasSubstr("Number of warmup records exceeds the maximum"));
210 }
211 
212 TEST_P(SavedModelBundleWarmupUtilTest, UnparsableRecord) {
213  string base_path = io::JoinPath(testing::TmpDir(), "UnparsableRecord");
214  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
215  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
216  string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
217  internal::WarmupConsts::kRequestsFileName);
218 
219  std::vector<string> warmup_records = {"malformed_record"};
220  TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, 10));
221  SavedModelBundle saved_model_bundle;
222  const Status status = RunSavedModelWarmup(
223  CreateModelWarmupOptions(), base_path,
224  [](PredictionLog prediction_log) { return absl::OkStatus(); });
225  ASSERT_FALSE(status.ok());
226  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
227  status.code())
228  << status;
229  EXPECT_THAT(status.ToString(),
230  ::testing::HasSubstr("Failed to parse warmup record"));
231 }
232 
233 TEST_P(SavedModelBundleWarmupUtilTest, RunFailure) {
234  string base_path = io::JoinPath(testing::TmpDir(), "RunFailure");
235  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
236  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
237  string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
238  internal::WarmupConsts::kRequestsFileName);
239 
240  int num_warmup_records = 10;
241  std::vector<string> warmup_records;
242  TF_ASSERT_OK(AddMixedWarmupData(&warmup_records));
243  TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, num_warmup_records));
244  SavedModelBundle saved_model_bundle;
245  AddSignatures(&saved_model_bundle.meta_graph_def);
246  Status status = RunSavedModelWarmup(
247  CreateModelWarmupOptions(), base_path, [](PredictionLog prediction_log) {
248  return errors::InvalidArgument("Run failed");
249  });
250  ASSERT_FALSE(status.ok());
251  EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
252  status.code())
253  << status;
254  EXPECT_THAT(status.ToString(), ::testing::HasSubstr("Run failed"));
255 }
256 INSTANTIATE_TEST_SUITE_P(ParallelWarmUp, SavedModelBundleWarmupUtilTest,
257  ::testing::Bool());
258 } // namespace
259 } // namespace internal
260 } // namespace serving
261 } // namespace tensorflow