16 #include "tensorflow_serving/core/simple_loader.h"
22 #include <gtest/gtest.h>
23 #include "absl/memory/memory.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 #include "tensorflow_serving/core/servable_data.h"
30 #include "tensorflow_serving/core/servable_id.h"
31 #include "tensorflow_serving/test_util/test_util.h"
33 namespace tensorflow {
37 using test_util::CreateProto;
38 using test_util::EqualsProto;
51 explicit Caller(State* state) : state_(state) { *state_ = State::kCtor; }
52 void DoStuff() { *state_ = State::kDoStuff; }
53 ~Caller() { *state_ = State::kDtor; }
59 Loader::Metadata CreateMetadata() {
return {ServableId{
"name", 42}}; }
61 class LoaderCreatorWithoutMetadata {
63 template <
typename ServableType,
typename... Args>
64 static std::unique_ptr<Loader> CreateSimpleLoader(
65 typename SimpleLoader<ServableType>::Creator creator, Args... args) {
66 return absl::make_unique<SimpleLoader<ServableType>>(creator, args...);
69 static Status Load(Loader* loader) {
return loader->Load(); }
72 class LoaderCreatorWithMetadata {
74 template <
typename ServableType,
typename... Args>
75 static std::unique_ptr<Loader> CreateSimpleLoader(
76 typename SimpleLoader<ServableType>::Creator creator, Args... args) {
77 return absl::make_unique<SimpleLoader<ServableType>>(
78 [creator](
const Loader::Metadata& metadata,
79 std::unique_ptr<ServableType>* servable) {
80 const auto& expected_metadata = CreateMetadata();
81 EXPECT_EQ(expected_metadata.servable_id, metadata.servable_id);
82 return creator(servable);
87 static Status Load(Loader* loader) {
88 return loader->LoadWithMetadata(CreateMetadata());
93 class SimpleLoaderTest :
public ::testing::Test {};
94 using LoaderCreatorTypes =
95 ::testing::Types<LoaderCreatorWithoutMetadata, LoaderCreatorWithMetadata>;
96 TYPED_TEST_SUITE(SimpleLoaderTest, LoaderCreatorTypes);
100 TYPED_TEST(SimpleLoaderTest, VerifyServableStates) {
101 State state = State::kNone;
102 auto loader = TypeParam::template CreateSimpleLoader<Caller>(
103 [&state](std::unique_ptr<Caller>* caller) {
104 caller->reset(
new Caller(&state));
107 SimpleLoader<Caller>::EstimateNoResources());
108 EXPECT_EQ(State::kNone, state);
109 const Status status = TypeParam::Load(loader.get());
110 TF_EXPECT_OK(status);
111 EXPECT_EQ(State::kCtor, state);
112 AnyPtr servable = loader->servable();
113 ASSERT_TRUE(servable.get<Caller>() !=
nullptr);
114 servable.get<Caller>()->DoStuff();
115 EXPECT_EQ(State::kDoStuff, state);
117 EXPECT_EQ(State::kDtor, state);
118 state = State::kNone;
119 loader.reset(
nullptr);
120 EXPECT_EQ(State::kNone, state);
123 TYPED_TEST(SimpleLoaderTest, ResourceEstimation) {
124 const auto want = CreateProto<ResourceAllocation>(
125 "resource_quantities { "
128 " kind: 'processing' "
132 auto loader = TypeParam::template CreateSimpleLoader<int>(
133 [](std::unique_ptr<int>* servable) {
134 servable->reset(
new int);
137 [&want](ResourceAllocation* estimate) {
143 ResourceAllocation got;
144 TF_ASSERT_OK(loader->EstimateResources(&got));
145 EXPECT_THAT(got, EqualsProto(want));
149 TF_ASSERT_OK(TypeParam::Load(loader.get()));
151 ResourceAllocation got;
152 TF_ASSERT_OK(loader->EstimateResources(&got));
153 EXPECT_THAT(got, EqualsProto(want));
157 TYPED_TEST(SimpleLoaderTest, ResourceEstimationWithPostLoadRelease) {
158 const auto pre_load_resources = CreateProto<ResourceAllocation>(
159 "resource_quantities { "
162 " kind: 'processing' "
166 const auto post_load_resources = CreateProto<ResourceAllocation>(
167 "resource_quantities { "
170 " kind: 'processing' "
174 auto loader = TypeParam::template CreateSimpleLoader<int>(
175 [](std::unique_ptr<int>* servable) {
176 servable->reset(
new int);
179 [&pre_load_resources](ResourceAllocation* estimate) {
180 *estimate = pre_load_resources;
183 absl::make_optional([&post_load_resources](ResourceAllocation* estimate) {
184 *estimate = post_load_resources;
189 for (
int i = 0; i < 2; ++i) {
190 ResourceAllocation got;
191 TF_ASSERT_OK(loader->EstimateResources(&got));
192 EXPECT_THAT(got, EqualsProto(pre_load_resources));
196 TF_ASSERT_OK(TypeParam::Load(loader.get()));
198 ResourceAllocation got;
199 TF_ASSERT_OK(loader->EstimateResources(&got));
200 EXPECT_THAT(got, EqualsProto(post_load_resources));
206 TYPED_TEST(SimpleLoaderTest, LoadError) {
207 auto loader = TypeParam::template CreateSimpleLoader<Caller>(
208 [](std::unique_ptr<Caller>* caller) {
209 return errors::InvalidArgument(
"No way!");
211 SimpleLoader<Caller>::EstimateNoResources());
212 const Status status = TypeParam::Load(loader.get());
213 EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
214 EXPECT_EQ(
"No way!", status.message());
217 TEST(SimpleLoaderCompatibilityTest, WithoutMetadata) {
218 auto loader_without_metadata = absl::make_unique<SimpleLoader<int>>(
219 [](std::unique_ptr<int>* servable) {
220 servable->reset(
new int);
223 SimpleLoader<int>::EstimateNoResources());
226 TF_EXPECT_OK(loader_without_metadata->Load());
227 TF_EXPECT_OK(loader_without_metadata->LoadWithMetadata(CreateMetadata()));
230 TEST(SimpleLoaderCompatibilityTest, WithMetadata) {
231 auto loader_with_metadata = absl::make_unique<SimpleLoader<int>>(
232 [](
const Loader::Metadata& metadata, std::unique_ptr<int>* servable) {
233 const auto& expected_metadata = CreateMetadata();
234 EXPECT_EQ(expected_metadata.servable_id, metadata.servable_id);
235 servable->reset(
new int);
238 SimpleLoader<int>::EstimateNoResources());
241 const Status error_status = loader_with_metadata->Load();
242 EXPECT_EQ(error::FAILED_PRECONDITION, error_status.code());
243 TF_EXPECT_OK(loader_with_metadata->LoadWithMetadata(CreateMetadata()));
248 template <
typename DataType,
typename ServableType>
249 class SimpleLoaderSourceAdapterImpl final
250 :
public SimpleLoaderSourceAdapter<DataType, ServableType> {
252 SimpleLoaderSourceAdapterImpl(
253 typename SimpleLoaderSourceAdapter<DataType, ServableType>::Creator
255 typename SimpleLoaderSourceAdapter<
256 DataType, ServableType>::ResourceEstimator resource_estimator)
257 : SimpleLoaderSourceAdapter<DataType, ServableType>(creator,
258 resource_estimator) {}
259 ~SimpleLoaderSourceAdapterImpl()
override { TargetBase<DataType>::Detach(); }
262 TEST(SimpleLoaderSourceAdapterTest, Basic) {
263 SimpleLoaderSourceAdapterImpl<string, string> adapter(
264 [](
const string& data, std::unique_ptr<string>* servable) {
265 servable->reset(
new string);
266 **servable = strings::StrCat(data,
"_was_here");
269 [](
const string& data, ResourceAllocation* output) {
270 ResourceAllocation::Entry* entry = output->add_resource_quantities();
271 entry->mutable_resource()->set_device(data);
272 entry->set_quantity(42);
276 const string kServableName =
"test_servable_name";
// Must start out false. EXPECT_TRUE(callback_called) later in this test reads
// it even if the aspired-versions callback never fires, and reading an
// uninitialized local bool is undefined behavior that could spuriously pass.
bool callback_called = false;
278 adapter.SetAspiredVersionsCallback(
279 [&](
const StringPiece servable_name,
280 std::vector<ServableData<std::unique_ptr<Loader>>> versions) {
281 callback_called =
true;
282 EXPECT_EQ(kServableName, servable_name);
283 EXPECT_EQ(1, versions.size());
284 TF_ASSERT_OK(versions[0].status());
285 std::unique_ptr<Loader> loader = versions[0].ConsumeDataOrDie();
286 ResourceAllocation estimate_given;
287 TF_ASSERT_OK(loader->EstimateResources(&estimate_given));
288 EXPECT_THAT(estimate_given, EqualsProto(CreateProto<ResourceAllocation>(
289 "resource_quantities { "
291 " device: 'test_data' "
295 TF_ASSERT_OK(loader->Load());
296 AnyPtr servable = loader->servable();
297 ASSERT_TRUE(servable.get<
string>() !=
nullptr);
298 EXPECT_EQ(
"test_data_was_here", *servable.get<
string>());
300 adapter.SetAspiredVersions(
301 kServableName, {ServableData<string>({kServableName, 0},
"test_data")});
302 EXPECT_TRUE(callback_called);
307 TEST(SimpleLoaderSourceAdapterTest, OkayToDeleteAdapter) {
308 std::unique_ptr<Loader> loader;
311 auto adapter = std::unique_ptr<SimpleLoaderSourceAdapter<string, string>>(
312 new SimpleLoaderSourceAdapterImpl<string, string>(
313 [](
const string& data, std::unique_ptr<string>* servable) {
314 servable->reset(
new string);
315 **servable = strings::StrCat(data,
"_was_here");
318 SimpleLoaderSourceAdapter<string, string>::EstimateNoResources()));
320 const string kServableName =
"test_servable_name";
321 adapter->SetAspiredVersionsCallback(
322 [&](
const StringPiece servable_name,
323 std::vector<ServableData<std::unique_ptr<Loader>>> versions) {
324 ASSERT_EQ(1, versions.size());
325 TF_ASSERT_OK(versions[0].status());
326 loader = versions[0].ConsumeDataOrDie();
328 adapter->SetAspiredVersions(
329 kServableName, {ServableData<string>({kServableName, 0},
"test_data")});
336 ResourceAllocation estimate_given;
337 TF_ASSERT_OK(loader->EstimateResources(&estimate_given));
338 TF_ASSERT_OK(loader->Load());