16 #include "tensorflow_serving/core/loader_harness.h"
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "absl/status/status.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 #include "tensorflow_serving/core/loader.h"
30 #include "tensorflow_serving/core/servable_id.h"
31 #include "tensorflow_serving/core/test_util/mock_loader.h"
32 #include "tensorflow_serving/test_util/test_util.h"
33 #include "tensorflow_serving/util/any_ptr.h"
35 namespace tensorflow {
39 using ::testing::HasSubstr;
40 using ::testing::InvokeWithoutArgs;
41 using ::testing::NiceMock;
42 using ::testing::Return;
43 using ::testing::StrictMock;
46 void QuiesceAndUnload(LoaderHarness*
const harness) {
47 TF_ASSERT_OK(harness->UnloadRequested());
48 TF_ASSERT_OK(harness->StartQuiescing());
49 TF_ASSERT_OK(harness->DoneQuiescing());
50 TF_ASSERT_OK(harness->Unload());
54 void EnableDestruction(LoaderHarness*
const harness) {
55 harness->Error(errors::Unknown(
"Erroring servable on purpose for shutdown"));
58 TEST(LoaderHarnessTest, Init) {
59 test_util::MockLoader* loader =
new StrictMock<test_util::MockLoader>;
60 LoaderHarness harness(ServableId{
"test", 0}, std::unique_ptr<Loader>(loader));
62 EXPECT_EQ((ServableId{
"test", 0}), harness.id());
64 EXPECT_EQ(harness.state(), harness.loader_state_snapshot<>().state);
67 TEST(LoaderHarnessTest, LoadRequested) {
68 test_util::MockLoader* loader =
new StrictMock<test_util::MockLoader>;
69 LoaderHarness harness(ServableId{
"test", 0}, std::unique_ptr<Loader>(loader));
71 TF_ASSERT_OK(harness.LoadRequested());
75 errors::Unknown(
"Transitions harness to a legally destructible state."));
78 TEST(LoaderHarnessTest, Quiesce) {
79 test_util::MockLoader* loader =
new StrictMock<test_util::MockLoader>;
80 const ServableId servable_id = {
"test", 0};
81 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
82 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
83 .WillOnce(Return(OkStatus()));
84 EXPECT_CALL(*loader, Unload()).WillOnce(Return());
86 TF_ASSERT_OK(harness.LoadRequested());
87 TF_ASSERT_OK(harness.LoadApproved());
88 TF_ASSERT_OK(harness.Load());
90 TF_ASSERT_OK(harness.UnloadRequested());
91 TF_ASSERT_OK(harness.StartQuiescing());
94 TF_ASSERT_OK(harness.DoneQuiescing());
98 TF_ASSERT_OK(harness.Unload());
101 TEST(LoaderHarnessTest, Load) {
102 test_util::MockLoader* loader =
new StrictMock<test_util::MockLoader>;
103 const ServableId servable_id = {
"test", 0};
105 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
107 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
108 .WillOnce(Return(OkStatus()));
110 std::unique_ptr<Thread> test_thread(
111 Env::Default()->StartThread(ThreadOptions(),
"test", [&harness]() {
112 TF_ASSERT_OK(harness.LoadRequested());
113 TF_ASSERT_OK(harness.LoadApproved());
114 EXPECT_TRUE(harness.Load().ok());
121 EnableDestruction(&harness);
124 TEST(LoaderHarnessTest, Unload) {
125 test_util::MockLoader* loader =
new StrictMock<test_util::MockLoader>;
126 const ServableId servable_id = {
"test", 0};
128 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
129 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
130 .WillOnce(Return(OkStatus()));
131 TF_ASSERT_OK(harness.LoadRequested());
132 TF_ASSERT_OK(harness.LoadApproved());
133 TF_ASSERT_OK(harness.Load());
135 EXPECT_CALL(*loader, Unload()).WillOnce(Return());
137 std::unique_ptr<Thread> test_thread(Env::Default()->StartThread(
138 ThreadOptions(),
"test", [&harness]() { QuiesceAndUnload(&harness); }));
145 TEST(LoaderHarnessTest, UnloadRequested) {
146 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
147 const ServableId servable_id = {
"test", 0};
148 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
149 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
150 .WillOnce(Return(OkStatus()));
151 TF_ASSERT_OK(harness.LoadRequested());
152 TF_ASSERT_OK(harness.LoadApproved());
153 TF_ASSERT_OK(harness.Load());
155 TF_ASSERT_OK(harness.UnloadRequested());
158 EnableDestruction(&harness);
161 TEST(LoaderHarnessTest, LoadApproved) {
162 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
163 LoaderHarness harness(ServableId{
"test", 0}, std::unique_ptr<Loader>(loader));
165 TF_ASSERT_OK(harness.LoadRequested());
166 TF_ASSERT_OK(harness.LoadApproved());
170 errors::Unknown(
"Transitions harness to a legally destructible state."));
173 TEST(LoaderHarnessTest, LoadError) {
174 test_util::MockLoader* loader =
new StrictMock<test_util::MockLoader>;
175 const ServableId servable_id = {
"test", 0};
176 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
178 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
179 .WillOnce(Return(errors::Unknown(
"test load error")));
181 std::unique_ptr<Thread> test_thread(
182 Env::Default()->StartThread(ThreadOptions(),
"test", [&harness]() {
183 TF_ASSERT_OK(harness.LoadRequested());
184 TF_ASSERT_OK(harness.LoadApproved());
185 Status status = harness.Load();
186 EXPECT_THAT(status.message(), HasSubstr(
"test load error"));
192 TEST(LoaderHarnessTest, ExternallySignalledError) {
193 LoaderHarness harness(ServableId{
"test", 0},
nullptr );
195 const Status status =
196 Status(
static_cast<tensorflow::errors::Code
>(absl::StatusCode::kUnknown),
197 "Some unknown error");
198 harness.Error(status);
200 EXPECT_EQ(status, harness.status());
203 TEST(LoaderHarnessTest, ExternallySignalledErrorWithCallback) {
204 const ServableId
id = {
"test_servable", 42};
206 Status(
static_cast<tensorflow::errors::Code
>(absl::StatusCode::kUnknown),
207 "Some unknown error");
208 Notification callback_called;
209 LoaderHarness::Options options;
210 options.error_callback = [&](
const ServableId& callback_id,
211 const Status& callback_error) {
212 EXPECT_EQ(
id, callback_id);
213 EXPECT_EQ(callback_error, error);
214 callback_called.Notify();
216 LoaderHarness harness(
id,
nullptr , options);
217 harness.Error(error);
218 callback_called.WaitForNotification();
221 TEST(LoaderHarnessTest, AdditionalState) {
222 std::unique_ptr<int> object(
new int(10));
223 LoaderHarness harness({
"test", 42},
nullptr, std::move(
object));
225 EXPECT_EQ(10, *harness.loader_state_snapshot<
int>().additional_state);
226 EXPECT_EQ(10, *harness.additional_state<
int>());
227 EXPECT_EQ(
nullptr, harness.additional_state<
float>());
230 TEST(LoaderHarnessTest, NoAdditionalState) {
231 LoaderHarness harness({
"test", 42},
nullptr);
234 EXPECT_FALSE(harness.loader_state_snapshot<
int>().additional_state);
235 EXPECT_EQ(
nullptr, harness.additional_state<
int>());
236 EXPECT_EQ(
nullptr, harness.additional_state<
float>());
239 TEST(LoaderHarnessTest, MultipleLoadRequestsOnlyFirstOneSucceeds) {
240 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
241 LoaderHarness harness(ServableId{
"test", 0}, std::unique_ptr<Loader>(loader));
243 TF_ASSERT_OK(harness.LoadRequested());
244 const Status second_request_status = harness.LoadRequested();
245 EXPECT_FALSE(second_request_status.ok());
246 EXPECT_EQ(error::FAILED_PRECONDITION, second_request_status.code());
247 EXPECT_THAT(second_request_status.message(),
248 HasSubstr(
"Duplicate load request"));
250 EnableDestruction(&harness);
253 TEST(LoaderHarnessTest, MultipleUnloadRequestsOnlyFirstOneSucceeds) {
254 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
255 const ServableId servable_id = {
"test", 0};
256 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
258 TF_ASSERT_OK(harness.LoadRequested());
259 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
260 .WillOnce(Return(OkStatus()));
261 TF_ASSERT_OK(harness.LoadApproved());
262 TF_ASSERT_OK(harness.Load());
264 TF_ASSERT_OK(harness.UnloadRequested());
265 const Status second_status = harness.UnloadRequested();
266 EXPECT_FALSE(second_status.ok());
267 EXPECT_EQ(error::FAILED_PRECONDITION, second_status.code());
269 second_status.message(),
270 HasSubstr(
"Servable not loaded, or unload already requested/ongoing"));
272 EnableDestruction(&harness);
275 TEST(LoaderHarnessTest, RetryOnLoadErrorFinallySucceeds) {
276 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
277 LoaderHarness::Options options;
278 options.max_num_load_retries = 1;
279 options.load_retry_interval_micros = 1;
280 const ServableId servable_id = {
"test", 0};
281 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader), options);
283 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
284 .WillOnce(InvokeWithoutArgs(
285 []() {
return errors::Unknown(
"test load error"); }))
286 .WillOnce(InvokeWithoutArgs([]() {
return OkStatus(); }));
287 TF_ASSERT_OK(harness.LoadRequested());
288 TF_ASSERT_OK(harness.LoadApproved());
289 TF_ASSERT_OK(harness.Load());
291 EnableDestruction(&harness);
294 TEST(LoaderHarnessTest, RetryOnLoadErrorFinallyFails) {
295 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
296 LoaderHarness::Options options;
297 options.max_num_load_retries = 1;
298 options.load_retry_interval_micros = 0;
299 const ServableId servable_id = {
"test", 0};
300 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader), options);
302 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
304 .WillRepeatedly(InvokeWithoutArgs(
305 []() {
return errors::Unknown(
"test load error"); }));
306 TF_ASSERT_OK(harness.LoadRequested());
307 TF_ASSERT_OK(harness.LoadApproved());
308 const Status status = harness.Load();
309 EXPECT_THAT(status.message(), HasSubstr(
"test load error"));
313 TEST(LoaderHarnessTest, RetryOnLoadErrorCancelledLoad) {
314 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
315 LoaderHarness::Options options;
316 options.max_num_load_retries = 10;
317 options.load_retry_interval_micros = 0;
318 const ServableId servable_id = {
"test", 0};
319 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader), options);
321 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
322 .WillOnce(InvokeWithoutArgs(
323 []() {
return errors::Unknown(
"test load error"); }))
325 .WillRepeatedly(InvokeWithoutArgs([]() {
return OkStatus(); }));
326 std::unique_ptr<Thread> test_thread(
327 Env::Default()->StartThread(ThreadOptions(),
"test", [&harness]() {
328 TF_ASSERT_OK(harness.LoadRequested());
329 TF_ASSERT_OK(harness.LoadApproved());
330 harness.set_should_retry([](absl::Status status) {
return false; });
331 const Status status = harness.Load();
332 EXPECT_THAT(status.message(), HasSubstr(
"test load error"));
337 TEST(LoaderHarnessTest, UnloadDueToCancelledLoad) {
338 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
340 const ServableId servable_id = {
"test", 0};
341 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
343 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
344 .WillOnce(InvokeWithoutArgs([]() {
345 Env::Default()->SleepForMicroseconds(1000000);
349 std::unique_ptr<Thread> test_thread(
350 Env::Default()->StartThread(ThreadOptions(),
"test", [&harness]() {
351 TF_ASSERT_OK(harness.LoadRequested());
352 TF_ASSERT_OK(harness.LoadApproved());
353 harness.set_should_retry([](absl::Status status) {
return false; });
354 const Status status = harness.Load();
355 EXPECT_THAT(status.message(), HasSubstr(
"cancelled"));
359 TEST(LoaderHarnessTest, UnloadDueToNonRetriableError) {
360 test_util::MockLoader* loader =
new NiceMock<test_util::MockLoader>;
362 const ServableId servable_id = {
"test", 0};
363 LoaderHarness harness(servable_id, std::unique_ptr<Loader>(loader));
365 EXPECT_CALL(*loader, LoadWithMetadata(Loader::Metadata{servable_id}))
366 .WillOnce(Return(absl::InvalidArgumentError(
"Non-retriable error.")))
367 .WillRepeatedly(InvokeWithoutArgs([]() {
368 Env::Default()->SleepForMicroseconds(1000000);
369 return absl::OkStatus();
372 std::unique_ptr<Thread> test_thread(
373 Env::Default()->StartThread(ThreadOptions(),
"test", [&harness]() {
374 TF_ASSERT_OK(harness.LoadRequested());
375 TF_ASSERT_OK(harness.LoadApproved());
376 harness.set_should_retry([](absl::Status status) {
377 return !absl::IsInvalidArgument(status);
379 const absl::Status status = harness.Load();
380 EXPECT_THAT(status.message(), HasSubstr(
"Non-retriable error."));
@ kLoadRequested
The manager has been requested to load this servable.
@ kReady
'loader_->Load()' has succeeded.
@ kUnloadRequested
The manager has been requested to unload this servable.
@ kDisabled
'loader_->Unload()' has finished.
@ kQuiesced
The servable has been made unavailable for serving.
@ kQuiescing
The servable is going to be made unavailable for serving.