16 #include "tensorflow_serving/servables/tensorflow/util.h"
22 #include "google/protobuf/wrappers.pb.h"
23 #include "tensorflow/cc/saved_model/signature_constants.h"
24 #include "tensorflow/core/example/example.pb.h"
25 #include "tensorflow/core/example/feature.pb.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/tensor_testutil.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/histogram/histogram.h"
30 #include "tensorflow/core/lib/monitoring/counter.h"
31 #include "tensorflow/core/lib/monitoring/sampler.h"
32 #include "tensorflow/core/platform/path.h"
33 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
34 #include "tensorflow_serving/test_util/test_util.h"
35 #include "tensorflow_serving/util/test_util/mock_file_probing_env.h"
// Everything below lives in the tensorflow namespace.
37 namespace tensorflow {
// Bring the proto matcher and the gMock actions/matchers used by the tests
// below into scope.
// NOTE(review): `Eq` and `_` are used by later tests but have no
// using-declaration among the visible lines — confirm they are declared.
41 using test_util::EqualsProto;
43 using ::testing::DoAll;
45 using ::testing::HasSubstr;
46 using ::testing::Return;
47 using ::testing::SetArgPointee;
// Test fixture for the InputUtilTest cases. The visible statements build
// Example protos whose single feature is an int64 list stored under a fixed
// key: "a" holds 11, "b" holds 22, and example_C() stores its `value`
// argument (default 33) under "c".
// NOTE(review): the helper signatures for the "a"/"b" builders, the local
// `example`/`feature` declarations, and the fixture members are not visible
// here — confirm against the full file.
49 class InputUtilTest :
public ::testing::Test {
// Feature "a": one int64 value, 11.
54 feature.mutable_int64_list()->add_value(11);
56 (*example.mutable_features()->mutable_feature())[
"a"] = feature;
// Feature "b": one int64 value, 22.
62 feature.mutable_int64_list()->add_value(22);
64 (*example.mutable_features()->mutable_feature())[
"b"] = feature;
// Feature "c": one int64 value taken from the `value` parameter.
68 Example example_C(
const int64_t value = 33) {
70 feature.mutable_int64_list()->add_value(value);
72 (*example.mutable_features()->mutable_feature())[
"c"] = feature;
// input_ is passed to the conversion untouched by this test; the call must
// return a non-OK status whose message contains "Input is empty".
80 TEST_F(InputUtilTest, Empty_KindNotSet) {
81 const Status status = InputToSerializedExampleTensor(input_, &tensor_);
// ASSERT (not EXPECT) so the message check below only runs on a failure status.
82 ASSERT_FALSE(status.ok());
83 EXPECT_THAT(status.message(), HasSubstr(
"Input is empty"));
// Selects the example_list field without adding any examples; the conversion
// must fail with a message containing "Input is empty".
86 TEST_F(InputUtilTest, Empty_ExampleList) {
// Touching the mutable accessor creates the (empty) example_list.
87 input_.mutable_example_list();
89 const Status status = InputToSerializedExampleTensor(input_, &tensor_);
90 ASSERT_FALSE(status.ok());
91 EXPECT_THAT(status.message(), HasSubstr(
"Input is empty"));
// Selects the example_list_with_context field without adding examples or a
// context; the conversion must fail with a message containing "Input is empty".
94 TEST_F(InputUtilTest, Empty_ExampleListWithContext) {
// Touching the mutable accessor creates the (empty) example_list_with_context.
95 input_.mutable_example_list_with_context();
97 const Status status = InputToSerializedExampleTensor(input_, &tensor_);
98 ASSERT_FALSE(status.ok());
99 EXPECT_THAT(status.message(), HasSubstr(
"Input is empty"));
// Converts an example_list holding example_A() and example_B() and checks
// that the output tensor has two tstring elements which parse back, in order,
// to the same two Example protos.
102 TEST_F(InputUtilTest, ExampleList) {
103 *input_.mutable_example_list()->mutable_examples()->Add() = example_A();
104 *input_.mutable_example_list()->mutable_examples()->Add() = example_B();
106 TF_ASSERT_OK(InputToSerializedExampleTensor(input_, &tensor_));
107 EXPECT_EQ(2, tensor_.NumElements());
108 const auto vec = tensor_.flat<tstring>();
// ASSERT on the size so the indexed reads below stay in bounds.
109 ASSERT_EQ(vec.size(), 2);
// The same Example object is reused for both parses.
110 Example serialized_example;
111 ASSERT_TRUE(serialized_example.ParseFromString(vec(0)));
112 EXPECT_THAT(serialized_example, EqualsProto(example_A()));
113 ASSERT_TRUE(serialized_example.ParseFromString(vec(1)));
114 EXPECT_THAT(serialized_example, EqualsProto(example_B()));
// Converts an example_list_with_context holding examples A and B plus
// example_C() as the context. Each serialized output must contain the
// context's "c" feature alongside the example's own feature ("a" or "b").
// NOTE(review): the declaration of `examples` and the scopes separating the
// two `serialized_example` declarations are not visible here — confirm.
117 TEST_F(InputUtilTest, ExampleListWithContext) {
119 input_.mutable_example_list_with_context()->mutable_examples();
120 *examples->Add() = example_A();
121 *examples->Add() = example_B();
122 *input_.mutable_example_list_with_context()->mutable_context() = example_C();
124 TF_ASSERT_OK(InputToSerializedExampleTensor(input_, &tensor_));
125 EXPECT_EQ(2, tensor_.NumElements());
126 const auto vec = tensor_.flat<tstring>();
127 ASSERT_EQ(vec.size(), 2);
// First output: context feature "c" plus example A's feature "a".
129 Example serialized_example;
130 ASSERT_TRUE(serialized_example.ParseFromString(vec(0)));
131 EXPECT_THAT(serialized_example.features().feature().at(
"c"),
132 EqualsProto(example_C().features().feature().at(
"c")));
133 EXPECT_THAT(serialized_example.features().feature().at(
"a"),
134 EqualsProto(example_A().features().feature().at(
"a")));
// Second output: context feature "c" plus example B's feature "b".
137 Example serialized_example;
138 ASSERT_TRUE(serialized_example.ParseFromString(vec(1)));
139 EXPECT_THAT(serialized_example.features().feature().at(
"c"),
140 EqualsProto(example_C().features().feature().at(
"c")));
141 EXPECT_THAT(serialized_example.features().feature().at(
"b"),
142 EqualsProto(example_B().features().feature().at(
"b")));
// The examples are A and example_C(64); the context is example_C() (default
// value 33). The first output must carry the context's "c" plus its own "a".
// The second output's "c" must equal example_C(64)'s, i.e. the example's own
// "c" is expected rather than the context's.
// NOTE(review): the declaration of `examples` and the scopes separating the
// two `serialized_example` declarations are not visible here — confirm.
147 TEST_F(InputUtilTest, ExampleListWithOverridingContext) {
149 input_.mutable_example_list_with_context()->mutable_examples();
150 *examples->Add() = example_A();
151 *examples->Add() = example_C(64);
152 *input_.mutable_example_list_with_context()->mutable_context() = example_C();
154 TF_ASSERT_OK(InputToSerializedExampleTensor(input_, &tensor_));
155 EXPECT_EQ(2, tensor_.NumElements());
156 const auto vec = tensor_.flat<tstring>();
157 ASSERT_EQ(vec.size(), 2);
// First output: "c" comes from the context, "a" from the example.
159 Example serialized_example;
160 ASSERT_TRUE(serialized_example.ParseFromString(vec(0)));
161 EXPECT_THAT(serialized_example.features().feature().at(
"c"),
162 EqualsProto(example_C().features().feature().at(
"c")));
163 EXPECT_THAT(serialized_example.features().feature().at(
"a"),
164 EqualsProto(example_A().features().feature().at(
"a")));
// Second output: "c" must match the example's value (64), not the context's.
167 Example serialized_example;
168 ASSERT_TRUE(serialized_example.ParseFromString(vec(1)));
169 EXPECT_THAT(serialized_example.features().feature().at(
"c"),
170 EqualsProto(example_C(64).features().feature().at(
"c")));
// Uses example_list_with_context with examples A and B and no context
// assigned in the visible lines; the two outputs must parse back to exactly
// example_A() and example_B().
// NOTE(review): the declaration of `examples` and the scopes separating the
// two `serialized_example` declarations are not visible here — confirm.
174 TEST_F(InputUtilTest, ExampleListWithContext_NoContext) {
176 input_.mutable_example_list_with_context()->mutable_examples();
177 *examples->Add() = example_A();
178 *examples->Add() = example_B();
180 TF_ASSERT_OK(InputToSerializedExampleTensor(input_, &tensor_));
181 EXPECT_EQ(2, tensor_.NumElements());
182 const auto vec = tensor_.flat<tstring>();
183 ASSERT_EQ(vec.size(), 2);
// First output equals example A unchanged.
185 Example serialized_example;
186 ASSERT_TRUE(serialized_example.ParseFromString(vec(0)));
187 EXPECT_THAT(serialized_example, EqualsProto(example_A()));
// Second output equals example B unchanged.
190 Example serialized_example;
191 ASSERT_TRUE(serialized_example.ParseFromString(vec(1)));
192 EXPECT_THAT(serialized_example, EqualsProto(example_B()));
// Sets only the context (example_C()) and adds no examples; the conversion
// must still fail with a message containing "Input is empty".
196 TEST_F(InputUtilTest, ExampleListWithContext_OnlyContext) {
199 *input_.mutable_example_list_with_context()->mutable_context() = example_C();
201 const Status status = InputToSerializedExampleTensor(input_, &tensor_);
202 ASSERT_FALSE(status.ok());
203 EXPECT_THAT(status.message(), HasSubstr(
"Input is empty"));
// Runs two independent conversions — one with two examples, one with a single
// example — and checks the element count of each output tensor.
// NOTE(review): the declarations of input_1/tensor_1/input_2/tensor_2 are not
// visible here, and no metric is asserted in the visible lines despite the
// test name — confirm against the full file.
206 TEST_F(InputUtilTest, RequestNumExamplesStreamz) {
// First request: examples A and B.
208 *input_1.mutable_example_list()->mutable_examples()->Add() = example_A();
209 *input_1.mutable_example_list()->mutable_examples()->Add() = example_B();
211 TF_ASSERT_OK(InputToSerializedExampleTensor(input_1, &tensor_1));
212 EXPECT_EQ(2, tensor_1.NumElements());
// Second request: example C only.
215 *input_2.mutable_example_list()->mutable_examples()->Add() = example_C();
217 TF_ASSERT_OK(InputToSerializedExampleTensor(input_2, &tensor_2));
218 EXPECT_EQ(1, tensor_2.NumElements());
// Snapshots the example-count histogram and the total-count counter for
// "model-name", records a request of 3 examples, snapshots again, and checks
// the deltas: histogram bucket index 2 grows by 1 and the total grows by 3.
// Deltas are compared, not absolute values, so earlier recordings under the
// same label do not affect the result.
221 TEST(ExampleCountsTest, Simple) {
// NOTE(review): `Histogram` is not referenced in the visible lines of this
// test — confirm whether this using-declaration is still needed.
222 using histogram::Histogram;
224 const HistogramProto before_histogram =
225 internal::GetExampleCounts()->GetCell(
"model-name")->value();
// NOTE(review): the counter value is stored in `int`; confirm the cell's value
// type so this does not narrow.
226 const int before_count =
227 internal::GetExampleCountTotal()->GetCell(
"model-name")->value();
228 RecordRequestExampleCount(
"model-name", 3);
229 const HistogramProto after_histogram =
230 internal::GetExampleCounts()->GetCell(
"model-name")->value();
231 const int after_count =
232 internal::GetExampleCountTotal()->GetCell(
"model-name")->value();
// Both snapshots need at least 3 buckets before bucket(2) is read.
234 ASSERT_GE(before_histogram.bucket().size(), 3);
235 ASSERT_GE(after_histogram.bucket().size(), 3);
236 EXPECT_EQ(1, after_histogram.bucket(2) - before_histogram.bucket(2));
237 EXPECT_EQ(3, after_count - before_count);
// Calls MakeModelSpec with only the model name; both optional arguments are
// passed as `{}`. Expects the name to be "foo", the signature name to be
// empty, and no version to be set.
240 TEST(ModelSpecTest, NoOptional) {
241 ModelSpec model_spec;
242 MakeModelSpec(
"foo", {}, {}, &model_spec);
243 EXPECT_THAT(model_spec.name(), Eq(
"foo"));
244 EXPECT_THAT(model_spec.signature_name(), ::testing::IsEmpty());
245 EXPECT_FALSE(model_spec.has_version());
// Calls MakeModelSpec with a signature name ("classify") and `{}` for the
// version. Expects the name and signature name to be set and no version.
// NOTE(review): the final argument of the MakeModelSpec call is not visible
// here — presumably `&model_spec`, as in the sibling tests; confirm.
248 TEST(ModelSpecTest, OptionalSignature) {
249 ModelSpec model_spec;
250 MakeModelSpec(
"foo", {
"classify"}, {},
252 EXPECT_THAT(model_spec.name(), Eq(
"foo"));
253 EXPECT_THAT(model_spec.signature_name(), Eq(
"classify"));
254 EXPECT_FALSE(model_spec.has_version());
// Calls MakeModelSpec with an explicitly empty signature name ("") and
// version 1. Expects the signature name to come back as
// kDefaultServingSignatureDefKey and the version value to be 1.
257 TEST(ModelSpecTest, EmptySignature) {
258 ModelSpec model_spec;
259 MakeModelSpec(
"foo", {
""}, {1}, &model_spec);
260 EXPECT_THAT(model_spec.name(), Eq(
"foo"));
261 EXPECT_THAT(model_spec.signature_name(), Eq(kDefaultServingSignatureDefKey));
262 EXPECT_THAT(model_spec.version().value(), Eq(1));
// Calls MakeModelSpec with `{}` for the signature and version 1. Expects the
// name to be "foo", the signature name to be empty, and the version value 1.
265 TEST(ModelSpecTest, OptionalVersion) {
266 ModelSpec model_spec;
267 MakeModelSpec(
"foo", {}, {1}, &model_spec);
268 EXPECT_THAT(model_spec.name(), Eq(
"foo"));
269 EXPECT_THAT(model_spec.signature_name(), ::testing::IsEmpty());
270 EXPECT_THAT(model_spec.version().value(), Eq(1));
// Calls MakeModelSpec with both a signature name ("classify") and version 1.
// Expects name, signature name and version value to all be set accordingly.
// NOTE(review): the final argument of the MakeModelSpec call is not visible
// here — presumably `&model_spec`, as in the sibling tests; confirm.
273 TEST(ModelSpecTest, AllOptionalSet) {
274 ModelSpec model_spec;
275 MakeModelSpec(
"foo", {
"classify"}, {1},
277 EXPECT_THAT(model_spec.name(), Eq(
"foo"));
278 EXPECT_THAT(model_spec.signature_name(), Eq(
"classify"));
279 EXPECT_THAT(model_spec.version().value(), Eq(1));
// Round-trips the signature-method-name-check flag: the getter must return
// whatever the setter was last called with, for both true and false.
// The flag is left set to false when the visible statements finish.
282 TEST(SignatureMethodNameCheckFeature, SetGet) {
283 SetSignatureMethodNameCheckFeature(
true);
284 EXPECT_TRUE(GetSignatureMethodNameCheckFeature());
286 SetSignatureMethodNameCheckFeature(
false);
287 EXPECT_FALSE(GetSignatureMethodNameCheckFeature());
// Drives EstimateResourceFromPathUsingDiskState against a mocked file-probing
// env describing "/foo/bar" with a single child "child" of size 100, and
// expects the result to equal test_util::GetExpectedResourceEstimate(100).
// NOTE(review): the wrappers around the GetFileSize action and around the
// estimator call are not visible here — presumably `.WillRepeatedly(` and
// `TF_ASSERT_OK(`; confirm.
290 TEST(ResourceEstimatorTest, EstimateResourceFromPathUsingDiskState) {
291 const string export_dir =
"/foo/bar";
292 const string child =
"child";
293 const string child_path = io::JoinPath(export_dir, child);
294 const double file_size = 100;
298 test_util::MockFileProbingEnv env;
// The export directory exists...
299 EXPECT_CALL(env, FileExists(export_dir))
300 .WillRepeatedly(Return(absl::OkStatus()));
// ...and lists exactly one child via the output argument.
301 EXPECT_CALL(env, GetChildren(export_dir, _))
302 .WillRepeatedly(DoAll(SetArgPointee<1>(std::vector<string>({child})),
303 Return(absl::OkStatus())));
// IsDirectory answers with a FailedPrecondition error for the child path.
304 EXPECT_CALL(env, IsDirectory(child_path))
305 .WillRepeatedly(Return(errors::FailedPrecondition(
"")));
// GetFileSize writes file_size into its output argument and returns OK.
306 EXPECT_CALL(env, GetFileSize(child_path, _))
308 DoAll(SetArgPointee<1>(file_size), Return(absl::OkStatus())));
310 ResourceAllocation actual;
312 EstimateResourceFromPathUsingDiskState(export_dir, &env, &actual));
314 ResourceAllocation expected =
315 test_util::GetExpectedResourceEstimate(file_size);
316 EXPECT_THAT(actual, EqualsProto(expected));
// Builds a two-entry string map and checks that GetMapKeys returns exactly
// its keys, "key1" and "key2". The matcher ignores element order.
319 TEST(GetMapKeysTest, GetKeys) {
320 std::map<string, string> map = {std::pair<string, string>(
"key1",
"value1"),
321 std::pair<string, string>(
"key2",
"value2")};
322 const auto result = GetMapKeys(map);
323 EXPECT_THAT(result, ::testing::UnorderedElementsAre(
"key1",
"key2"));
// Checks SetDifference on two input pairs:
//   {"a","b","c"} minus {"a","b"}      -> {"c"}
//   {"a","b","c"} minus {"a","b","d"}  -> {"c"}
// The second pair has "d" only in the subtracted set; it must not appear in
// the result.
326 TEST(SetDifferenceTEST, GetDiff) {
// NOTE(review): `result` is not referenced in the visible lines of this test
// — confirm whether it can be removed.
327 std::set<string> result;
328 EXPECT_THAT(SetDifference({
"a",
"b",
"c"}, {
"a",
"b"}),
329 ::testing::UnorderedElementsAre(
"c"));
330 EXPECT_THAT(SetDifference({
"a",
"b",
"c"}, {
"a",
"b",
"d"}),
331 ::testing::UnorderedElementsAre(
"c"));