TensorFlow Serving C++ API Documentation
loader_harness_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/loader_harness.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "absl/status/status.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 #include "tensorflow_serving/core/loader.h"
30 #include "tensorflow_serving/core/servable_id.h"
31 #include "tensorflow_serving/core/test_util/mock_loader.h"
32 #include "tensorflow_serving/test_util/test_util.h"
33 #include "tensorflow_serving/util/any_ptr.h"
34 
35 namespace tensorflow {
36 namespace serving {
37 namespace {
38 
39 using ::testing::HasSubstr;
40 using ::testing::InvokeWithoutArgs;
41 using ::testing::NiceMock;
42 using ::testing::Return;
43 using ::testing::StrictMock;
44 
45 // Walks 'harness' through a sequence of transitions from kReady to kDisabled.
46 void QuiesceAndUnload(LoaderHarness* const harness) {
47  TF_ASSERT_OK(harness->UnloadRequested());
48  TF_ASSERT_OK(harness->StartQuiescing());
49  TF_ASSERT_OK(harness->DoneQuiescing());
50  TF_ASSERT_OK(harness->Unload());
51 }
52 
53 // Makes it s.t. it's legal to destruct 'harness'.
54 void EnableDestruction(LoaderHarness* const harness) {
55  harness->Error(errors::Unknown("Erroring servable on purpose for shutdown"));
56 }
57 
58 TEST(LoaderHarnessTest, Init) {
59  test_util::MockLoader* loader = new StrictMock<test_util::MockLoader>;
60  LoaderHarness harness(ServableId{"test", 0}, std::unique_ptr<Loader>(loader));
61 
62  EXPECT_EQ((ServableId{"test", 0}), harness.id());
63  EXPECT_EQ(LoaderHarness::State::kNew, harness.state());
64  EXPECT_EQ(harness.state(), harness.loader_state_snapshot<>().state);
65 }
66 
67 TEST(LoaderHarnessTest, LoadRequested) {
68  test_util::MockLoader* loader = new StrictMock<test_util::MockLoader>;
69  LoaderHarness harness(ServableId{"test", 0}, std::unique_ptr<Loader>(loader));
70 
71  TF_ASSERT_OK(harness.LoadRequested());
72  EXPECT_EQ(LoaderHarness::State::kLoadRequested, harness.state());
73 
74  harness.Error(
75  errors::Unknown("Transitions harness to a legally destructible state."));
76 }
77 
78 TEST(LoaderHarnessTest, Quiesce) {
79  test_util::MockLoader* loader = new StrictMock<test_util::MockLoader>;
80  const ServableId servable_id = {"test", 0};
81  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
82  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
83  .WillOnce(Return(OkStatus()));
84  EXPECT_CALL(*loader, Unload()).WillOnce(Return());
85 
86  TF_ASSERT_OK(harness.LoadRequested());
87  TF_ASSERT_OK(harness.LoadApproved());
88  TF_ASSERT_OK(harness.Load());
89 
90  TF_ASSERT_OK(harness.UnloadRequested());
91  TF_ASSERT_OK(harness.StartQuiescing());
92  EXPECT_EQ(LoaderHarness::State::kQuiescing, harness.state());
93 
94  TF_ASSERT_OK(harness.DoneQuiescing());
95  EXPECT_EQ(LoaderHarness::State::kQuiesced, harness.state());
96 
97  // Otherwise we break the dtor invariant and check-fail.
98  TF_ASSERT_OK(harness.Unload());
99 }
100 
101 TEST(LoaderHarnessTest, Load) {
102  test_util::MockLoader* loader = new StrictMock<test_util::MockLoader>;
103  const ServableId servable_id = {"test", 0};
104 
105  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
106 
107  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
108  .WillOnce(Return(OkStatus()));
109  {
110  std::unique_ptr<Thread> test_thread(
111  Env::Default()->StartThread(ThreadOptions(), "test", [&harness]() {
112  TF_ASSERT_OK(harness.LoadRequested());
113  TF_ASSERT_OK(harness.LoadApproved());
114  EXPECT_TRUE(harness.Load().ok());
115  }));
116  // Deleting the thread here forces join and ensures that
117  // LoaderHarness::Load() returns.
118  }
119  EXPECT_EQ(LoaderHarness::State::kReady, harness.state());
120 
121  EnableDestruction(&harness);
122 }
123 
124 TEST(LoaderHarnessTest, Unload) {
125  test_util::MockLoader* loader = new StrictMock<test_util::MockLoader>;
126  const ServableId servable_id = {"test", 0};
127 
128  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
129  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
130  .WillOnce(Return(OkStatus()));
131  TF_ASSERT_OK(harness.LoadRequested());
132  TF_ASSERT_OK(harness.LoadApproved());
133  TF_ASSERT_OK(harness.Load());
134 
135  EXPECT_CALL(*loader, Unload()).WillOnce(Return());
136  {
137  std::unique_ptr<Thread> test_thread(Env::Default()->StartThread(
138  ThreadOptions(), "test", [&harness]() { QuiesceAndUnload(&harness); }));
139  // Deleting the thread here forces join and ensures that
140  // LoaderHarness::Unload() returns.
141  }
142  EXPECT_EQ(LoaderHarness::State::kDisabled, harness.state());
143 }
144 
145 TEST(LoaderHarnessTest, UnloadRequested) {
146  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
147  const ServableId servable_id = {"test", 0};
148  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
149  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
150  .WillOnce(Return(OkStatus()));
151  TF_ASSERT_OK(harness.LoadRequested());
152  TF_ASSERT_OK(harness.LoadApproved());
153  TF_ASSERT_OK(harness.Load());
154 
155  TF_ASSERT_OK(harness.UnloadRequested());
156  EXPECT_EQ(LoaderHarness::State::kUnloadRequested, harness.state());
157 
158  EnableDestruction(&harness);
159 }
160 
161 TEST(LoaderHarnessTest, LoadApproved) {
162  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
163  LoaderHarness harness(ServableId{"test", 0}, std::unique_ptr<Loader>(loader));
164 
165  TF_ASSERT_OK(harness.LoadRequested());
166  TF_ASSERT_OK(harness.LoadApproved());
167  EXPECT_EQ(LoaderHarness::State::kLoadApproved, harness.state());
168 
169  harness.Error(
170  errors::Unknown("Transitions harness to a legally destructible state."));
171 }
172 
173 TEST(LoaderHarnessTest, LoadError) {
174  test_util::MockLoader* loader = new StrictMock<test_util::MockLoader>;
175  const ServableId servable_id = {"test", 0};
176  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
177 
178  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
179  .WillOnce(Return(errors::Unknown("test load error")));
180  {
181  std::unique_ptr<Thread> test_thread(
182  Env::Default()->StartThread(ThreadOptions(), "test", [&harness]() {
183  TF_ASSERT_OK(harness.LoadRequested());
184  TF_ASSERT_OK(harness.LoadApproved());
185  Status status = harness.Load();
186  EXPECT_THAT(status.message(), HasSubstr("test load error"));
187  }));
188  }
189  EXPECT_EQ(LoaderHarness::State::kError, harness.state());
190 }
191 
192 TEST(LoaderHarnessTest, ExternallySignalledError) {
193  LoaderHarness harness(ServableId{"test", 0}, nullptr /* no loader */);
194  EXPECT_EQ(LoaderHarness::State::kNew, harness.state());
195  const Status status =
196  Status(static_cast<tensorflow::errors::Code>(absl::StatusCode::kUnknown),
197  "Some unknown error");
198  harness.Error(status);
199  EXPECT_EQ(LoaderHarness::State::kError, harness.state());
200  EXPECT_EQ(status, harness.status());
201 }
202 
203 TEST(LoaderHarnessTest, ExternallySignalledErrorWithCallback) {
204  const ServableId id = {"test_servable", 42};
205  const Status error =
206  Status(static_cast<tensorflow::errors::Code>(absl::StatusCode::kUnknown),
207  "Some unknown error");
208  Notification callback_called;
209  LoaderHarness::Options options;
210  options.error_callback = [&](const ServableId& callback_id,
211  const Status& callback_error) {
212  EXPECT_EQ(id, callback_id);
213  EXPECT_EQ(callback_error, error);
214  callback_called.Notify();
215  };
216  LoaderHarness harness(id, nullptr /* no loader */, options);
217  harness.Error(error);
218  callback_called.WaitForNotification();
219 }
220 
221 TEST(LoaderHarnessTest, AdditionalState) {
222  std::unique_ptr<int> object(new int(10));
223  LoaderHarness harness({"test", 42}, nullptr, std::move(object));
224 
225  EXPECT_EQ(10, *harness.loader_state_snapshot<int>().additional_state);
226  EXPECT_EQ(10, *harness.additional_state<int>());
227  EXPECT_EQ(nullptr, harness.additional_state<float>());
228 }
229 
230 TEST(LoaderHarnessTest, NoAdditionalState) {
231  LoaderHarness harness({"test", 42}, nullptr);
232 
233  // Will return nullptr when there is no metadata set.
234  EXPECT_FALSE(harness.loader_state_snapshot<int>().additional_state);
235  EXPECT_EQ(nullptr, harness.additional_state<int>());
236  EXPECT_EQ(nullptr, harness.additional_state<float>());
237 }
238 
239 TEST(LoaderHarnessTest, MultipleLoadRequestsOnlyFirstOneSucceeds) {
240  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
241  LoaderHarness harness(ServableId{"test", 0}, std::unique_ptr<Loader>(loader));
242 
243  TF_ASSERT_OK(harness.LoadRequested());
244  const Status second_request_status = harness.LoadRequested();
245  EXPECT_FALSE(second_request_status.ok());
246  EXPECT_EQ(error::FAILED_PRECONDITION, second_request_status.code());
247  EXPECT_THAT(second_request_status.message(),
248  HasSubstr("Duplicate load request"));
249 
250  EnableDestruction(&harness);
251 }
252 
253 TEST(LoaderHarnessTest, MultipleUnloadRequestsOnlyFirstOneSucceeds) {
254  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
255  const ServableId servable_id = {"test", 0};
256  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
257 
258  TF_ASSERT_OK(harness.LoadRequested());
259  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
260  .WillOnce(Return(OkStatus()));
261  TF_ASSERT_OK(harness.LoadApproved());
262  TF_ASSERT_OK(harness.Load());
263 
264  TF_ASSERT_OK(harness.UnloadRequested());
265  const Status second_status = harness.UnloadRequested();
266  EXPECT_FALSE(second_status.ok());
267  EXPECT_EQ(error::FAILED_PRECONDITION, second_status.code());
268  EXPECT_THAT(
269  second_status.message(),
270  HasSubstr("Servable not loaded, or unload already requested/ongoing"));
271 
272  EnableDestruction(&harness);
273 }
274 
275 TEST(LoaderHarnessTest, RetryOnLoadErrorFinallySucceeds) {
276  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
277  LoaderHarness::Options options;
278  options.max_num_load_retries = 1;
279  options.load_retry_interval_micros = 1;
280  const ServableId servable_id = {"test", 0};
281  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader), options);
282 
283  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
284  .WillOnce(InvokeWithoutArgs(
285  []() { return errors::Unknown("test load error"); }))
286  .WillOnce(InvokeWithoutArgs([]() { return OkStatus(); }));
287  TF_ASSERT_OK(harness.LoadRequested());
288  TF_ASSERT_OK(harness.LoadApproved());
289  TF_ASSERT_OK(harness.Load());
290 
291  EnableDestruction(&harness);
292 }
293 
294 TEST(LoaderHarnessTest, RetryOnLoadErrorFinallyFails) {
295  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
296  LoaderHarness::Options options;
297  options.max_num_load_retries = 1;
298  options.load_retry_interval_micros = 0;
299  const ServableId servable_id = {"test", 0};
300  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader), options);
301 
302  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
303  .Times(2)
304  .WillRepeatedly(InvokeWithoutArgs(
305  []() { return errors::Unknown("test load error"); }));
306  TF_ASSERT_OK(harness.LoadRequested());
307  TF_ASSERT_OK(harness.LoadApproved());
308  const Status status = harness.Load();
309  EXPECT_THAT(status.message(), HasSubstr("test load error"));
310 }
311 
312 // Tests cancelling load retries.
313 TEST(LoaderHarnessTest, RetryOnLoadErrorCancelledLoad) {
314  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
315  LoaderHarness::Options options;
316  options.max_num_load_retries = 10;
317  options.load_retry_interval_micros = 0;
318  const ServableId servable_id = {"test", 0};
319  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader), options);
320 
321  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
322  .WillOnce(InvokeWithoutArgs(
323  []() { return errors::Unknown("test load error"); }))
324  // If the load is called again, we return OkStatus() to fail the test.
325  .WillRepeatedly(InvokeWithoutArgs([]() { return OkStatus(); }));
326  std::unique_ptr<Thread> test_thread(
327  Env::Default()->StartThread(ThreadOptions(), "test", [&harness]() {
328  TF_ASSERT_OK(harness.LoadRequested());
329  TF_ASSERT_OK(harness.LoadApproved());
330  harness.set_should_retry([](absl::Status status) { return false; });
331  const Status status = harness.Load();
332  EXPECT_THAT(status.message(), HasSubstr("test load error"));
333  }));
334 }
335 
336 // Tests unload when ongoing load is cancelled.
337 TEST(LoaderHarnessTest, UnloadDueToCancelledLoad) {
338  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
339 
340  const ServableId servable_id = {"test", 0};
341  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
342 
343  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
344  .WillOnce(InvokeWithoutArgs([]() {
345  Env::Default()->SleepForMicroseconds(1000000);
346  return OkStatus();
347  }));
348 
349  std::unique_ptr<Thread> test_thread(
350  Env::Default()->StartThread(ThreadOptions(), "test", [&harness]() {
351  TF_ASSERT_OK(harness.LoadRequested());
352  TF_ASSERT_OK(harness.LoadApproved());
353  harness.set_should_retry([](absl::Status status) { return false; });
354  const Status status = harness.Load();
355  EXPECT_THAT(status.message(), HasSubstr("cancelled"));
356  }));
357 }
358 
359 TEST(LoaderHarnessTest, UnloadDueToNonRetriableError) {
360  test_util::MockLoader* loader = new NiceMock<test_util::MockLoader>;
361 
362  const ServableId servable_id = {"test", 0};
363  LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
364 
365  EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
366  .WillOnce(Return(absl::InvalidArgumentError("Non-retriable error.")))
367  .WillRepeatedly(InvokeWithoutArgs([]() {
368  Env::Default()->SleepForMicroseconds(1000000);
369  return absl::OkStatus();
370  }));
371 
372  std::unique_ptr<Thread> test_thread(
373  Env::Default()->StartThread(ThreadOptions(), "test", [&harness]() {
374  TF_ASSERT_OK(harness.LoadRequested());
375  TF_ASSERT_OK(harness.LoadApproved());
376  harness.set_should_retry([](absl::Status status) {
377  return !absl::IsInvalidArgument(status);
378  });
379  const absl::Status status = harness.Load();
380  EXPECT_THAT(status.message(), HasSubstr("Non-retriable error."));
381  }));
382 }
383 
384 } // namespace
385 } // namespace serving
386 } // namespace tensorflow
- kLoadRequested: The manager has been requested to load this servable.
- kReady: 'loader_->Load()' has succeeded.
- kUnloadRequested: The manager has been requested to unload this servable.
- kDisabled: 'loader_->Unload()' has finished.
- kQuiesced: The servable has been made unavailable for serving.
- kQuiescing: The servable is going to be made unavailable for serving.