TensorFlow Serving C++ API Documentation
saved_model_warmup_test.cc
1 /* Copyright 2018 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup.h"
17 
18 #include <vector>
19 
20 #include "google/protobuf/wrappers.pb.h"
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/cc/saved_model/constants.h"
24 #include "tensorflow/cc/saved_model/signature_constants.h"
25 #include "tensorflow/core/example/example.pb.h"
26 #include "tensorflow/core/example/feature.pb.h"
27 #include "tensorflow/core/framework/tensor.pb.h"
28 #include "tensorflow/core/framework/tensor_shape.pb.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/io/path.h"
31 #include "tensorflow/core/lib/io/record_writer.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/threadpool_options.h"
35 #include "tensorflow_serving/apis/classification.pb.h"
36 #include "tensorflow_serving/apis/inference.pb.h"
37 #include "tensorflow_serving/apis/input.pb.h"
38 #include "tensorflow_serving/apis/model.pb.h"
39 #include "tensorflow_serving/apis/predict.pb.h"
40 #include "tensorflow_serving/apis/prediction_log.pb.h"
41 #include "tensorflow_serving/apis/regression.pb.h"
42 #include "tensorflow_serving/core/test_util/mock_session.h"
43 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_test_util.h"
44 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
45 
46 namespace tensorflow {
47 namespace serving {
48 
49 namespace {
50 
51 using test_util::MockSession;
52 using ::testing::_;
53 using ::testing::DoAll;
54 using ::testing::Return;
55 using ::testing::SetArgPointee;
56 using ::testing::SizeIs;
57 
58 class SavedModelBundleWarmupOptionsTest
59  : public ::testing::TestWithParam<bool> {
60  public:
61  bool EnableNumRequestIterations() { return GetParam(); }
62 
63  ModelWarmupOptions GetModelWarmupOptions() {
64  ModelWarmupOptions options;
65  if (EnableNumRequestIterations()) {
66  options.mutable_num_request_iterations()->set_value(2);
67  }
68  return options;
69  }
70 
71  int GetNumRequestIterations() {
72  if (EnableNumRequestIterations()) {
73  return 2;
74  }
75  return 1;
76  }
77 };
78 
79 TEST_P(SavedModelBundleWarmupOptionsTest, MixedWarmupData) {
80  string base_path = io::JoinPath(testing::TmpDir(), "MixedWarmupData");
81  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
82  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
83  string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
84  internal::WarmupConsts::kRequestsFileName);
85 
86  int num_warmup_records = 10;
87  std::vector<string> warmup_records;
88  TF_ASSERT_OK(AddMixedWarmupData(&warmup_records));
89  TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, num_warmup_records));
90  SavedModelBundle saved_model_bundle;
91  AddSignatures(&saved_model_bundle.meta_graph_def);
92  MockSession* mock = new MockSession;
93  saved_model_bundle.session.reset(mock);
94  Tensor scores(DT_FLOAT, TensorShape({1, 1}));
95  Tensor classes(DT_STRING, TensorShape({1, 1}));
96  // Regress and Predict cases
97  EXPECT_CALL(*mock, Run(_, _, SizeIs(1), _, _, _, _))
98  .Times(num_warmup_records * 2 * GetNumRequestIterations())
99  .WillRepeatedly(DoAll(SetArgPointee<4>(std::vector<Tensor>({scores})),
100  Return(absl::OkStatus())));
101  // Classify case
102  EXPECT_CALL(*mock, Run(_, _, SizeIs(2), _, _, _, _))
103  .Times(num_warmup_records * GetNumRequestIterations())
104  .WillRepeatedly(
105  DoAll(SetArgPointee<4>(std::vector<Tensor>({classes, scores})),
106  Return(absl::OkStatus())));
107  // MultiInference case
108  EXPECT_CALL(*mock, Run(_, _, SizeIs(3), _, _, _, _))
109  .Times(num_warmup_records * GetNumRequestIterations())
110  .WillRepeatedly(DoAll(
111  SetArgPointee<4>(std::vector<Tensor>({classes, scores, scores})),
112  Return(absl::OkStatus())));
113  TF_EXPECT_OK(RunSavedModelWarmup(GetModelWarmupOptions(), RunOptions(),
114  base_path, &saved_model_bundle));
115 }
116 INSTANTIATE_TEST_SUITE_P(WarmupOptions, SavedModelBundleWarmupOptionsTest,
117  ::testing::Bool());
118 
119 TEST(SavedModelBundleWarmupTest, UnsupportedLogType_SessionRun) {
120  string base_path = io::JoinPath(testing::TmpDir(), "SessionRun");
121  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
122  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
123  string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
124  internal::WarmupConsts::kRequestsFileName);
125 
126  std::vector<string> warmup_records;
127  // Add unsupported log type
128  TF_ASSERT_OK(AddToWarmupData(&warmup_records, PredictionLog::kSessionRunLog));
129  TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, 10));
130  SavedModelBundle saved_model_bundle;
131  AddSignatures(&saved_model_bundle.meta_graph_def);
132  MockSession* mock = new MockSession;
133  saved_model_bundle.session.reset(mock);
134  EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
135  .WillRepeatedly(Return(absl::OkStatus()));
136  const Status status = RunSavedModelWarmup(ModelWarmupOptions(), RunOptions(),
137  base_path, &saved_model_bundle);
138  ASSERT_FALSE(status.ok());
139  EXPECT_EQ(::tensorflow::error::UNIMPLEMENTED, status.code()) << status;
140  EXPECT_THAT(status.ToString(),
141  ::testing::HasSubstr("Unsupported log_type for warmup"));
142 }
143 
144 TEST(SavedModelBundleWarmupTest, UnsupportedLogType_PredictStreamed) {
145  string base_path = io::JoinPath(testing::TmpDir(), "PredictStreamed");
146  TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(
147  io::JoinPath(base_path, kSavedModelAssetsExtraDirectory)));
148  string fname = io::JoinPath(base_path, kSavedModelAssetsExtraDirectory,
149  internal::WarmupConsts::kRequestsFileName);
150 
151  std::vector<string> warmup_records;
152  // Add unsupported log type
153  TF_ASSERT_OK(
154  AddToWarmupData(&warmup_records, PredictionLog::kPredictStreamedLog));
155  TF_ASSERT_OK(WriteWarmupData(fname, warmup_records, 10));
156  SavedModelBundle saved_model_bundle;
157  AddSignatures(&saved_model_bundle.meta_graph_def);
158  MockSession* mock = new MockSession;
159  saved_model_bundle.session.reset(mock);
160  EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
161  .WillRepeatedly(Return(absl::OkStatus()));
162  const Status status = RunSavedModelWarmup(ModelWarmupOptions(), RunOptions(),
163  base_path, &saved_model_bundle);
164  ASSERT_FALSE(status.ok());
165  EXPECT_EQ(::tensorflow::error::UNIMPLEMENTED, status.code()) << status;
166  EXPECT_THAT(status.ToString(),
167  ::testing::HasSubstr("Unsupported log_type for warmup"));
168 }
169 
170 } // namespace
171 
172 } // namespace serving
173 } // namespace tensorflow