TensorFlow Serving C++ API Documentation
request_logger_test.cc
1 /* Copyright 2017 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/core/request_logger.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "google/protobuf/any.pb.h"
22 #include "google/protobuf/wrappers.pb.h"
23 #include "google/protobuf/message.h"
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include "tensorflow/cc/saved_model/tag_constants.h"
27 #include "tensorflow/core/framework/tensor.pb.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow_serving/apis/logging.pb.h"
32 #include "tensorflow_serving/apis/model.pb.h"
33 #include "tensorflow_serving/apis/predict.pb.h"
34 #include "tensorflow_serving/config/logging_config.pb.h"
35 #include "tensorflow_serving/core/log_collector.h"
36 #include "tensorflow_serving/core/test_util/mock_log_collector.h"
37 #include "tensorflow_serving/core/test_util/mock_prediction_stream_logger.h"
38 #include "tensorflow_serving/core/test_util/mock_request_logger.h"
39 #include "tensorflow_serving/test_util/test_util.h"
40 
41 namespace tensorflow {
42 namespace serving {
43 namespace {
44 
45 using test_util::MockPredictionStreamLogger;
46 using ::testing::_;
47 using ::testing::DoAll;
48 using ::testing::HasSubstr;
49 using ::testing::Invoke;
50 using ::testing::NiceMock;
51 using ::testing::Return;
52 using ::testing::WithArg;
53 
54 class RequestLoggerTest : public ::testing::Test {
55  protected:
56  RequestLoggerTest() {
57  LoggingConfig logging_config;
58  logging_config.mutable_sampling_config()->set_sampling_rate(1.0);
59  log_collector_ = new NiceMock<MockLogCollector>();
60  request_logger_ = std::shared_ptr<NiceMock<MockRequestLogger>>(
61  new NiceMock<MockRequestLogger>(logging_config, model_tags_,
62  log_collector_));
63  }
64 
65  const std::vector<string> model_tags_ = {kSavedModelTagServe,
66  kSavedModelTagTpu};
67  NiceMock<MockLogCollector>* log_collector_;
68  std::shared_ptr<NiceMock<MockRequestLogger>> request_logger_;
69 };
70 
71 TEST_F(RequestLoggerTest, Simple) {
72  ModelSpec model_spec;
73  model_spec.set_name("model");
74  model_spec.mutable_version()->set_value(10);
75 
76  PredictRequest request;
77  *request.mutable_model_spec() = model_spec;
78 
79  PredictResponse response;
80  response.mutable_outputs()->insert({"tensor", TensorProto()});
81  LogMetadata log_metadata;
82  *log_metadata.mutable_model_spec() = model_spec;
83 
84  EXPECT_CALL(*request_logger_, CreateLogMessage(_, _, _, _))
85  .WillOnce(Invoke([&](const google::protobuf::Message& actual_request,
86  const google::protobuf::Message& actual_response,
87  const LogMetadata& actual_log_metadata,
88  std::unique_ptr<google::protobuf::Message>* log) {
89  EXPECT_THAT(static_cast<const PredictRequest&>(actual_request),
90  test_util::EqualsProto(request));
91  EXPECT_THAT(static_cast<const PredictResponse&>(actual_response),
92  test_util::EqualsProto(PredictResponse()));
93  LogMetadata expected_log_metadata = log_metadata;
94  expected_log_metadata.mutable_sampling_config()->set_sampling_rate(1.0);
95  *expected_log_metadata.mutable_saved_model_tags() = {
96  model_tags_.begin(), model_tags_.end()};
97  EXPECT_THAT(actual_log_metadata,
98  test_util::EqualsProto(expected_log_metadata));
99  *log =
100  std::unique_ptr<google::protobuf::Any>(new google::protobuf::Any());
101  return OkStatus();
102  }));
103  EXPECT_CALL(*log_collector_, CollectMessage(_)).WillOnce(Return(OkStatus()));
104  TF_ASSERT_OK(request_logger_->Log(request, PredictResponse(), log_metadata));
105 }
106 
107 TEST_F(RequestLoggerTest, ErroringCreateLogMessage) {
108  EXPECT_CALL(*request_logger_, CreateLogMessage(_, _, _, _))
109  .WillRepeatedly(Return(errors::Internal("Error")));
110  EXPECT_CALL(*log_collector_, CollectMessage(_)).Times(0);
111  const auto error_status =
112  request_logger_->Log(PredictRequest(), PredictResponse(), LogMetadata());
113  ASSERT_FALSE(error_status.ok());
114  EXPECT_THAT(error_status.message(), HasSubstr("Error"));
115 }
116 
117 TEST_F(RequestLoggerTest, ErroringCollectMessage) {
118  EXPECT_CALL(*request_logger_, CreateLogMessage(_, _, _, _))
119  .WillRepeatedly(Invoke([&](const google::protobuf::Message& actual_request,
120  const google::protobuf::Message& actual_response,
121  const LogMetadata& actual_log_metadata,
122  std::unique_ptr<google::protobuf::Message>* log) {
123  *log =
124  std::unique_ptr<google::protobuf::Any>(new google::protobuf::Any());
125  return OkStatus();
126  }));
127  EXPECT_CALL(*log_collector_, CollectMessage(_))
128  .WillRepeatedly(Return(errors::Internal("Error")));
129  const auto error_status =
130  request_logger_->Log(PredictRequest(), PredictResponse(), LogMetadata());
131  ASSERT_FALSE(error_status.ok());
132  EXPECT_THAT(error_status.message(), HasSubstr("Error"));
133 }
134 
135 TEST_F(RequestLoggerTest, LoggingStreamSucceeds) {
136  auto logger = std::make_unique<MockPredictionStreamLogger>();
137 
138  LogMetadata expected_log_metadata;
139  expected_log_metadata.mutable_model_spec()->set_name("model");
140  EXPECT_CALL(*request_logger_,
141  FillLogMetadata(test_util::EqualsProto(expected_log_metadata)))
142  .WillOnce(Return(expected_log_metadata));
143 
144  request_logger_->MaybeStartLoggingStream<PredictRequest, PredictResponse>(
145  expected_log_metadata,
146  [logger_ptr = logger.get()]() { return logger_ptr; });
147 
148  EXPECT_CALL(*logger, CreateLogMessage(
149  test_util::EqualsProto(expected_log_metadata), _))
150  .WillOnce(DoAll(WithArg<1>([](std::unique_ptr<google::protobuf::Message>* log) {
151  *log = std::make_unique<google::protobuf::Any>();
152  }),
153  Return(OkStatus())));
154  EXPECT_CALL(*log_collector_, CollectMessage(_)).WillOnce(Return(OkStatus()));
155  TF_ASSERT_OK(logger->LogMessage());
156 }
157 
158 TEST_F(RequestLoggerTest, LoggingStreamRequestLoggerDiesBeforeStreamCloses) {
159  auto logger = std::make_unique<MockPredictionStreamLogger>();
160 
161  LogMetadata expected_log_metadata;
162  expected_log_metadata.mutable_model_spec()->set_name("model");
163 
164  EXPECT_CALL(*request_logger_,
165  FillLogMetadata(test_util::EqualsProto(expected_log_metadata)))
166  .WillOnce(Return(expected_log_metadata));
167  request_logger_->MaybeStartLoggingStream<PredictRequest, PredictResponse>(
168  expected_log_metadata,
169  [logger_ptr = logger.get()]() { return logger_ptr; });
170 
171  EXPECT_CALL(*logger, CreateLogMessage(
172  test_util::EqualsProto(expected_log_metadata), _))
173  .WillOnce(DoAll(WithArg<1>([](std::unique_ptr<google::protobuf::Message>* log) {
174  *log = std::make_unique<google::protobuf::Any>();
175  }),
176  Return(OkStatus())));
177  EXPECT_CALL(*log_collector_, CollectMessage(_)).Times(0);
178 
179  request_logger_.reset();
180  TF_ASSERT_OK(logger->LogMessage());
181 }
182 
183 } // namespace
184 } // namespace serving
185 } // namespace tensorflow