TensorFlow Serving C++ API Documentation
stream_logger.h
1 /* Copyright 2023 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef THIRD_PARTY_TENSORFLOW_SERVING_CORE_STREAM_LOGGER_H_
17 #define THIRD_PARTY_TENSORFLOW_SERVING_CORE_STREAM_LOGGER_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <type_traits>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/status/status.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow_serving/apis/logging.pb.h"
28 
29 namespace tensorflow {
30 namespace serving {
31 
32 // Simple logger for a stream of requests and responses. In practice, the
33 // lifetime of this class should be attached to the lifetime of a stream.
34 //
35 // The class being templated on requests and responses is to avoid RTTI in the
36 // subclasses.
37 // Not thread-safe.
38 template <typename Request, typename Response>
39 class StreamLogger {
40  public:
41  StreamLogger() {
42  static_assert((std::is_base_of<google::protobuf::Message, Request>::value),
43  "Request must be a proto type.");
44  static_assert((std::is_base_of<google::protobuf::Message, Response>::value),
45  "Response must be a proto type.");
46  }
47 
48  virtual ~StreamLogger() = default;
49 
50  virtual void LogStreamRequest(Request request) = 0;
51  virtual void LogStreamResponse(Response response) = 0;
52 
53  using LogMessageFn = std::function<absl::Status(const google::protobuf::Message&)>;
54  // Registers a log callback to be invoked when calling LogMessage();
55  void AddLogCallback(const LogMetadata& log_metadata,
56  LogMessageFn log_message_fn);
57 
58  // Logs the message with all requests and responses accumulated so far, and
59  // invokes all log callbacks sequentially. Upon return, any subsequent calls
60  // to any other methods of this class will result in undefined behavior. On
61  // multiple callbacks, we return error from the first failed one (and continue
62  // attempting the rest).
63  absl::Status LogMessage();
64 
65  private:
66  virtual absl::Status CreateLogMessage(
67  const LogMetadata& log_metadata,
68  std::unique_ptr<google::protobuf::Message>* log) = 0;
69 
70  struct StreamLogCallback {
71  LogMetadata log_metadata;
72  LogMessageFn log_message_fn;
73  };
74 
75  std::vector<StreamLogCallback> callbacks_;
76 };
77 
78 /*************************Implementation Details******************************/
79 
80 template <typename Request, typename Response>
82  absl::Status status;
83  for (const auto& callback : callbacks_) {
84  std::unique_ptr<google::protobuf::Message> log;
85  absl::Status create_status = CreateLogMessage(callback.log_metadata, &log);
86  if (create_status.ok()) {
87  status.Update(callback.log_message_fn(*log));
88  } else {
89  LOG_EVERY_N_SEC(ERROR, 30)
90  << "Failed creating log message for streaming request. Log metadata: "
91  << callback.log_metadata.DebugString()
92  << ", error: " << create_status;
93  status.Update(create_status);
94  }
95  }
96  return status;
97 }
98 
99 template <typename Request, typename Response>
100 void StreamLogger<Request, Response>::AddLogCallback(
101  const LogMetadata& log_metadata, LogMessageFn log_message_fn) {
102  StreamLogCallback callback;
103  callback.log_metadata = log_metadata;
104  callback.log_message_fn = std::move(log_message_fn);
105  callbacks_.push_back(std::move(callback));
106 }
107 
108 } // namespace serving
109 } // namespace tensorflow
110 
111 #endif // THIRD_PARTY_TENSORFLOW_SERVING_CORE_STREAM_LOGGER_H_