16 #include "tensorflow_serving/core/request_logger.h"
21 #include "google/protobuf/any.pb.h"
22 #include "google/protobuf/wrappers.pb.h"
23 #include "google/protobuf/message.h"
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include "tensorflow/cc/saved_model/tag_constants.h"
27 #include "tensorflow/core/framework/tensor.pb.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow_serving/apis/logging.pb.h"
32 #include "tensorflow_serving/apis/model.pb.h"
33 #include "tensorflow_serving/apis/predict.pb.h"
34 #include "tensorflow_serving/config/logging_config.pb.h"
35 #include "tensorflow_serving/core/log_collector.h"
36 #include "tensorflow_serving/core/test_util/mock_log_collector.h"
37 #include "tensorflow_serving/core/test_util/mock_prediction_stream_logger.h"
38 #include "tensorflow_serving/core/test_util/mock_request_logger.h"
39 #include "tensorflow_serving/test_util/test_util.h"
41 namespace tensorflow {
45 using test_util::MockPredictionStreamLogger;
47 using ::testing::DoAll;
48 using ::testing::HasSubstr;
49 using ::testing::Invoke;
50 using ::testing::NiceMock;
51 using ::testing::Return;
52 using ::testing::WithArg;
// Test fixture: builds a NiceMock RequestLogger backed by a NiceMock
// LogCollector, configured with sampling_rate 1.0 and the model_tags_ below.
// NOTE(review): this extract drops several original lines (the access
// specifier, the SetUp() signature, the remaining MockRequestLogger
// constructor arguments, the end of the model_tags_ initializer and the
// closing brace). The comments here describe only the visible code.
54 class RequestLoggerTest :
public ::testing::Test {
// A sampling rate of 1.0 presumably makes every request eligible for
// logging; confirm against RequestLogger's sampling logic. The Simple test
// expects this sampling config to appear in the LogMetadata handed to
// CreateLogMessage().
57 LoggingConfig logging_config;
58 logging_config.mutable_sampling_config()->set_sampling_rate(1.0);
// The collector is allocated with raw `new` and kept as a raw pointer so
// tests can set expectations on it.
// NOTE(review): ownership presumably passes to the request logger through the
// constructor arguments not shown here; verify. If so, log_collector_
// dangles once request_logger_ is destroyed.
59 log_collector_ =
new NiceMock<MockLogCollector>();
60 request_logger_ = std::shared_ptr<NiceMock<MockRequestLogger>>(
61 new NiceMock<MockRequestLogger>(logging_config, model_tags_,
// SavedModel tags given to the logger. The Simple test expects them to be
// copied into LogMetadata.saved_model_tags.
65 const std::vector<string> model_tags_ = {kSavedModelTagServe,
67 NiceMock<MockLogCollector>* log_collector_;
// Held as a shared_ptr. LoggingStreamRequestLoggerDiesBeforeStreamCloses
// calls reset() on it to destroy the logger while a logging stream is open.
68 std::shared_ptr<NiceMock<MockRequestLogger>> request_logger_;
// Verifies that Log() passes the request, response and metadata to
// CreateLogMessage(). The metadata must be augmented with the logger's
// sampling config and SavedModel tags. The resulting message must then be
// sent to the collector exactly once.
// NOTE(review): the declaration of model_spec (presumably `ModelSpec
// model_spec;`) and the tail of the CreateLogMessage lambda are missing from
// this extract. `_` is also not among the visible using-declarations;
// presumably it is declared on an omitted line.
71 TEST_F(RequestLoggerTest, Simple) {
73 model_spec.set_name(
"model");
74 model_spec.mutable_version()->set_value(10);
76 PredictRequest request;
77 *request.mutable_model_spec() = model_spec;
// NOTE(review): `response` is populated here but never used. Log() below is
// called with an empty PredictResponse(), and the lambda expects an empty
// one. Either pass `response` (and expect it) or drop this variable.
79 PredictResponse response;
80 response.mutable_outputs()->insert({
"tensor", TensorProto()});
81 LogMetadata log_metadata;
82 *log_metadata.mutable_model_spec() = model_spec;
// The mocked CreateLogMessage checks its arguments. The expected metadata is
// the caller's log_metadata plus sampling_rate 1.0 and the fixture's
// model_tags_.
84 EXPECT_CALL(*request_logger_, CreateLogMessage(_, _, _, _))
85 .WillOnce(Invoke([&](
const google::protobuf::Message& actual_request,
86 const google::protobuf::Message& actual_response,
87 const LogMetadata& actual_log_metadata,
88 std::unique_ptr<google::protobuf::Message>* log) {
89 EXPECT_THAT(
static_cast<const PredictRequest&
>(actual_request),
90 test_util::EqualsProto(request));
91 EXPECT_THAT(
static_cast<const PredictResponse&
>(actual_response),
92 test_util::EqualsProto(PredictResponse()));
93 LogMetadata expected_log_metadata = log_metadata;
94 expected_log_metadata.mutable_sampling_config()->set_sampling_rate(1.0);
95 *expected_log_metadata.mutable_saved_model_tags() = {
96 model_tags_.begin(), model_tags_.end()};
97 EXPECT_THAT(actual_log_metadata,
98 test_util::EqualsProto(expected_log_metadata));
// An empty Any is produced as the log message. It is presumably assigned to
// *log on a line not shown.
100 std::unique_ptr<google::protobuf::Any>(
new google::protobuf::Any());
// WillOnce with no WillRepeatedly: CollectMessage must be called exactly once.
103 EXPECT_CALL(*log_collector_, CollectMessage(_)).WillOnce(Return(OkStatus()));
104 TF_ASSERT_OK(request_logger_->Log(request, PredictResponse(), log_metadata));
// Verifies that an error returned by CreateLogMessage() is propagated by
// Log(), and that the collector is never called in that case.
107 TEST_F(RequestLoggerTest, ErroringCreateLogMessage) {
108 EXPECT_CALL(*request_logger_, CreateLogMessage(_, _, _, _))
109 .WillRepeatedly(Return(errors::Internal(
"Error")));
// No message may reach the collector when message creation fails.
110 EXPECT_CALL(*log_collector_, CollectMessage(_)).Times(0);
111 const auto error_status =
112 request_logger_->Log(PredictRequest(), PredictResponse(), LogMetadata());
113 ASSERT_FALSE(error_status.ok());
// The returned status must carry the injected error text.
114 EXPECT_THAT(error_status.message(), HasSubstr(
"Error"));
// Verifies that when CreateLogMessage() produces a message but the
// collector's CollectMessage() fails, Log() returns that error.
// NOTE(review): the tail of the CreateLogMessage lambda (presumably assigning
// the Any below to *log and returning OkStatus()) is missing from this
// extract.
117 TEST_F(RequestLoggerTest, ErroringCollectMessage) {
118 EXPECT_CALL(*request_logger_, CreateLogMessage(_, _, _, _))
119 .WillRepeatedly(Invoke([&](
const google::protobuf::Message& actual_request,
120 const google::protobuf::Message& actual_response,
121 const LogMetadata& actual_log_metadata,
122 std::unique_ptr<google::protobuf::Message>* log) {
124 std::unique_ptr<google::protobuf::Any>(
new google::protobuf::Any());
// Make the collector fail for every message it receives.
127 EXPECT_CALL(*log_collector_, CollectMessage(_))
128 .WillRepeatedly(Return(errors::Internal(
"Error")));
129 const auto error_status =
130 request_logger_->Log(PredictRequest(), PredictResponse(), LogMetadata());
131 ASSERT_FALSE(error_status.ok());
// The collector's error text must surface in the returned status.
132 EXPECT_THAT(error_status.message(), HasSubstr(
"Error"));
// Verifies the streaming path. MaybeStartLoggingStream() must call
// FillLogMetadata() with the given metadata. A later LogMessage() on the
// stream logger must then build one message from that metadata and send it
// to the collector exactly once.
// NOTE(review): the closing of the WithArg<1> lambda and the test's closing
// brace are missing from this extract.
135 TEST_F(RequestLoggerTest, LoggingStreamSucceeds) {
// The test owns the mock stream logger; the factory below hands out a
// non-owning pointer to it.
136 auto logger = std::make_unique<MockPredictionStreamLogger>();
138 LogMetadata expected_log_metadata;
139 expected_log_metadata.mutable_model_spec()->set_name(
"model");
140 EXPECT_CALL(*request_logger_,
141 FillLogMetadata(test_util::EqualsProto(expected_log_metadata)))
142 .WillOnce(Return(expected_log_metadata));
144 request_logger_->MaybeStartLoggingStream<PredictRequest, PredictResponse>(
145 expected_log_metadata,
146 [logger_ptr = logger.get()]() {
return logger_ptr; });
// The stream logger must create exactly one message using the filled
// metadata. It writes an empty Any into the out-parameter (argument 1) and
// reports success.
148 EXPECT_CALL(*logger, CreateLogMessage(
149 test_util::EqualsProto(expected_log_metadata), _))
150 .WillOnce(DoAll(WithArg<1>([](std::unique_ptr<google::protobuf::Message>* log) {
151 *log = std::make_unique<google::protobuf::Any>();
153 Return(OkStatus())));
154 EXPECT_CALL(*log_collector_, CollectMessage(_)).WillOnce(Return(OkStatus()));
155 TF_ASSERT_OK(logger->LogMessage());
// Verifies behavior when the RequestLogger is destroyed before the stream's
// LogMessage() runs. LogMessage() must still succeed, and the stream logger
// must still build its message exactly once. The collector must never
// receive it.
// NOTE(review): if the request logger owns log_collector_ (not visible in
// this extract), the collector is destroyed by request_logger_.reset(). Note
// that the Times(0) expectation on it is registered before the reset.
158 TEST_F(RequestLoggerTest, LoggingStreamRequestLoggerDiesBeforeStreamCloses) {
159 auto logger = std::make_unique<MockPredictionStreamLogger>();
161 LogMetadata expected_log_metadata;
162 expected_log_metadata.mutable_model_spec()->set_name(
"model");
164 EXPECT_CALL(*request_logger_,
165 FillLogMetadata(test_util::EqualsProto(expected_log_metadata)))
166 .WillOnce(Return(expected_log_metadata));
167 request_logger_->MaybeStartLoggingStream<PredictRequest, PredictResponse>(
168 expected_log_metadata,
169 [logger_ptr = logger.get()]() {
return logger_ptr; });
// The message is still created once; argument 1 receives an empty Any.
171 EXPECT_CALL(*logger, CreateLogMessage(
172 test_util::EqualsProto(expected_log_metadata), _))
173 .WillOnce(DoAll(WithArg<1>([](std::unique_ptr<google::protobuf::Message>* log) {
174 *log = std::make_unique<google::protobuf::Any>();
176 Return(OkStatus())));
177 EXPECT_CALL(*log_collector_, CollectMessage(_)).Times(0);
// Destroy the request logger (drops the fixture's shared_ptr) while the
// stream logger is still alive.
179 request_logger_.reset();
180 TF_ASSERT_OK(logger->LogMessage());