16 #include "tensorflow_serving/batching/batching_util.h"
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/types.h"
30 namespace tensorflow {
34 using ::testing::ElementsAre;
35 using ::testing::Pair;
36 using ::testing::UnorderedElementsAre;
41 std::vector<std::pair<string, Tensor>> CreateInputsWithTensorShapes(
42 const std::vector<TensorShape>& shapes) {
43 std::vector<std::pair<string, Tensor>> inputs;
44 for (
int i = 0; i < shapes.size(); ++i) {
45 inputs.push_back({
"x" + std::to_string(i), Tensor(DT_FLOAT, shapes[i])});
50 TEST(BatchingUtilTest, CalculateMaxDimSizes) {
51 const std::vector<TensorShape> shapes1{{10, 20, 30}, {10, 100}};
52 std::vector<std::pair<string, Tensor>> inputs1 =
53 CreateInputsWithTensorShapes(shapes1);
54 const std::vector<TensorShape> shapes2{{20, 50, 15}, {20, 101}};
55 std::vector<std::pair<string, Tensor>> inputs2 =
56 CreateInputsWithTensorShapes(shapes2);
57 std::vector<std::vector<std::pair<string, Tensor>>> batch{inputs1, inputs2};
58 std::map<string, std::vector<int>> max_dim_sizes =
59 CalculateMaxDimSizes(batch);
60 EXPECT_THAT(max_dim_sizes,
61 UnorderedElementsAre(Pair(
"x0", ElementsAre(20, 50, 30)),
62 Pair(
"x1", ElementsAre(20, 101))));
65 TEST(BatchingUtilTest, AddPadding) {
66 const std::vector<int> max_dim_sizes{20, 100, 200};
67 const std::vector<DataType> types{
68 DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
69 DT_UINT16, DT_INT8, DT_STRING, DT_BOOL, DT_COMPLEX64,
70 DT_COMPLEX128, DT_INT64, DT_QINT8, DT_QUINT8, DT_QINT16,
71 DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE};
72 Status padding_status;
73 for (DataType type : types) {
74 Tensor tensor(type, {10, 20, 30});
75 #define INIT_TYPE(T) \
76 if (type == DataTypeToEnum<T>::value) { \
77 tensor.flat<T>().setConstant(T()); \
79 TF_CALL_ALL_TYPES(INIT_TYPE);
80 TF_CALL_QUANTIZED_TYPES(INIT_TYPE);
82 TF_CALL_quint16(INIT_TYPE);
83 TF_CALL_qint16(INIT_TYPE);
86 padding_status = AddPadding(tensor, max_dim_sizes, &padded_tensor);
87 ASSERT_EQ(absl::OkStatus(), padding_status);
88 EXPECT_EQ(TensorShape({10, 100, 200}), padded_tensor.shape());
92 TEST(BatchingUtilTest, AddPaddingTensorWithUnsupportedRank) {
93 const std::vector<int> max_dim_sizes{1, 1, 1, 1, 1, 1, 1};
94 const Tensor tensor(DT_FLOAT, {1, 1, 1, 1, 1, 1, 1});
96 ASSERT_EQ(errors::InvalidArgument(
97 "Only tensors with rank from 1 to 6 can be padded."),
98 AddPadding(tensor, max_dim_sizes, &padded_tensor));