TensorFlow Serving C++ API Documentation
test_util.h
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_SERVING_TEST_UTIL_TEST_UTIL_H_
17 #define TENSORFLOW_SERVING_TEST_UTIL_TEST_UTIL_H_
18 
#include <functional>
#include <string>
#include <utility>

#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include <gmock/gmock.h>
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
31 
32 namespace tensorflow {
33 namespace serving {
34 namespace test_util {
35 
36 // Creates a proto message of type T from a textual representation.
37 template <typename T>
38 T CreateProto(const string& textual_proto);
39 
40 // Return an absolute runfiles srcdir given a path relative to
41 // tensorflow.
42 string TensorflowTestSrcDirPath(const string& relative_path);
43 
44 // Return an absolute runfiles srcdir given a path relative to
45 // tensorflow/contrib.
46 string ContribTestSrcDirPath(const string& relative_path);
47 
48 // Return an absolute runfiles srcdir given a path relative to
49 // tensorflow_serving.
50 string TestSrcDirPath(const string& relative_path);
51 
52 // Simple implementation of a proto matcher comparing string representations.
53 //
54 // IMPORTANT: Only use this for protos whose textual representation is
55 // deterministic (that may not be the case for the map collection type).
57  public:
58  explicit ProtoStringMatcher(const string& expected);
59  explicit ProtoStringMatcher(const google::protobuf::Message& expected);
60 
61  template <typename Message>
62  bool MatchAndExplain(const Message& p,
63  ::testing::MatchResultListener* /* listener */) const;
64 
65  void DescribeTo(::std::ostream* os) const { *os << expected_; }
66  void DescribeNegationTo(::std::ostream* os) const {
67  *os << "not equal to expected message: " << expected_;
68  }
69 
70  private:
71  const string expected_;
72 };
73 
74 // Polymorphic matcher to compare any two protos.
75 inline ::testing::PolymorphicMatcher<ProtoStringMatcher> EqualsProto(
76  const string& x) {
77  return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x));
78 }
79 
80 // Polymorphic matcher to compare any two protos.
81 inline ::testing::PolymorphicMatcher<ProtoStringMatcher> EqualsProto(
82  const google::protobuf::Message& x) {
83  return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x));
84 }
85 
87 // Implementation details. API readers need not read.
88 
89 template <typename T>
90 T CreateProto(const string& textual_proto) {
91  T proto;
92  CHECK(protobuf::TextFormat::ParseFromString(textual_proto, &proto));
93  return proto;
94 }
95 
96 template <typename Message>
97 bool ProtoStringMatcher::MatchAndExplain(
98  const Message& p, ::testing::MatchResultListener* /* listener */) const {
99  // Need to CreateProto and then print as string so that the formatting
100  // matches exactly.
101  return p.SerializeAsString() ==
102  CreateProto<Message>(expected_).SerializeAsString();
103 }
104 
105 // An implementation of thread::ThreadPoolInterface that delegates calls to
106 // thread::ThreadPool but
107 class CountingThreadPool : public thread::ThreadPoolInterface {
108  public:
109  CountingThreadPool(Env* env, const string& name, int num_threads)
110  : thread_pool_(env, name, num_threads), num_scheduled_(0) {}
111  ~CountingThreadPool() = default;
112 
113  void Schedule(std::function<void()> fn) override {
114  {
115  mutex_lock l(mu_);
116  num_scheduled_++;
117  }
118  thread_pool_.Schedule(fn);
119  }
120 
121  int NumThreads() const override { return thread_pool_.NumThreads(); }
122 
123  int CurrentThreadId() const override {
124  return thread_pool_.CurrentThreadId();
125  }
126 
127  int NumScheduled() const {
128  {
129  mutex_lock l(mu_);
130  return num_scheduled_;
131  }
132  }
133 
134  private:
135  thread::ThreadPool thread_pool_;
136  mutable mutex mu_;
137  int32 num_scheduled_ TF_GUARDED_BY(mu_);
138 };
139 
140 } // namespace test_util
141 } // namespace serving
142 } // namespace tensorflow
143 
144 #endif // TENSORFLOW_SERVING_TEST_UTIL_TEST_UTIL_H_