TensorFlow Serving C++ API Documentation
class_registration_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/util/class_registration.h"
17 
18 #include <map>
19 #include <memory>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow_serving/util/class_registration_test.pb.h"
25 
26 using ::testing::Pair;
27 using ::testing::UnorderedElementsAre;
28 
29 namespace tensorflow {
30 namespace serving {
31 namespace {
32 
33 // A base class and associated registry.
34 class MyBaseClass {
35  public:
36  virtual ~MyBaseClass() = default;
37  virtual string class_name() const = 0;
38  virtual string config_data() const = 0;
39 };
40 DEFINE_CLASS_REGISTRY(MyBaseClassRegistry, MyBaseClass);
41 #define REGISTER_MY_BASE_CLASS(SubClassCreator, ConfigProto) \
42  REGISTER_CLASS(MyBaseClassRegistry, MyBaseClass, SubClassCreator, \
43  ConfigProto);
44 
45 // A subclass of MyBaseClass that should be instantiated via Config1.
46 class SubClass1 : public MyBaseClass {
47  public:
48  static Status Create(const Config1& config,
49  std::unique_ptr<MyBaseClass>* result) {
50  if (config.GetDescriptor()->full_name() != "tensorflow.serving.Config1") {
51  return errors::InvalidArgument("Wrong type of config proto: ",
52  config.GetDescriptor()->full_name());
53  }
54  auto* raw_result = new SubClass1();
55  result->reset(raw_result);
56  raw_result->config_ = config;
57  return OkStatus();
58  }
59 
60  string class_name() const override { return "SubClass1"; }
61 
62  string config_data() const override { return config_.string_field(); }
63 
64  private:
65  Config1 config_;
66 };
67 REGISTER_MY_BASE_CLASS(SubClass1, Config1);
68 
69 // A subclass of MyBaseClass that should be instantiated via Config2.
70 class SubClass2 : public MyBaseClass {
71  public:
72  static Status Create(const Config2& config,
73  std::unique_ptr<MyBaseClass>* result) {
74  if (config.GetDescriptor()->full_name() != "tensorflow.serving.Config2") {
75  return errors::InvalidArgument("Wrong type of config proto: ",
76  config.GetDescriptor()->full_name());
77  }
78  auto* raw_result = new SubClass2();
79  result->reset(raw_result);
80  raw_result->config_ = config;
81  return OkStatus();
82  }
83 
84  string class_name() const override { return "SubClass2"; }
85 
86  string config_data() const override { return config_.string_field(); }
87 
88  private:
89  Config2 config_;
90 };
91 
92 // A creator of SubClass2 objects.
93 class SubClass2Creator {
94  public:
95  static Status Create(const Config2& config,
96  std::unique_ptr<MyBaseClass>* result) {
97  return SubClass2::Create(config, result);
98  }
99 };
100 REGISTER_MY_BASE_CLASS(SubClass2Creator, Config2);
101 
102 TEST(ClassRegistrationTest, InstantiateFromRawConfig) {
103  std::unique_ptr<MyBaseClass> loaded_subclass;
104 
105  Config1 config1;
106  config1.set_string_field("foo");
107  TF_ASSERT_OK(MyBaseClassRegistry::Create(config1, &loaded_subclass));
108  EXPECT_EQ("SubClass1", loaded_subclass->class_name());
109  EXPECT_EQ("foo", loaded_subclass->config_data());
110 
111  Config2 config2;
112  config2.set_string_field("bar");
113  TF_ASSERT_OK(MyBaseClassRegistry::Create(config2, &loaded_subclass));
114  EXPECT_EQ("SubClass2", loaded_subclass->class_name());
115  EXPECT_EQ("bar", loaded_subclass->config_data());
116 }
117 
118 TEST(ClassRegistrationTest, InstantiateFromAny) {
119  std::unique_ptr<MyBaseClass> loaded_subclass;
120 
121  Config1 config1;
122  config1.set_string_field("foo");
123  google::protobuf::Any any_config1;
124  any_config1.PackFrom(config1);
125  TF_ASSERT_OK(
126  MyBaseClassRegistry::CreateFromAny(any_config1, &loaded_subclass));
127  EXPECT_EQ("SubClass1", loaded_subclass->class_name());
128  EXPECT_EQ("foo", loaded_subclass->config_data());
129 
130  Config2 config2;
131  config2.set_string_field("bar");
132  google::protobuf::Any any_config2;
133  any_config2.PackFrom(config2);
134  TF_ASSERT_OK(
135  MyBaseClassRegistry::CreateFromAny(any_config2, &loaded_subclass));
136  EXPECT_EQ("SubClass2", loaded_subclass->class_name());
137  EXPECT_EQ("bar", loaded_subclass->config_data());
138 }
139 
140 // A second registry for MyBaseClass, with a different name.
141 DEFINE_CLASS_REGISTRY(AlternateMyBaseClassRegistry, MyBaseClass);
142 #define REGISTER_MY_BASE_CLASS_USING_ALTERNATE_REGISTRY(SubClassCreator, \
143  ConfigProto) \
144  REGISTER_CLASS(AlternateMyBaseClassRegistry, MyBaseClass, SubClassCreator, \
145  ConfigProto);
146 
147 // A subclass of MyBaseClass that should be instantiated via Config1, and is
148 // registered in the alternate registry.
149 class AlternateSubClass : public MyBaseClass {
150  public:
151  static Status Create(const Config1& config,
152  std::unique_ptr<MyBaseClass>* result) {
153  if (config.GetDescriptor()->full_name() != "tensorflow.serving.Config1") {
154  return errors::InvalidArgument("Wrong type of config proto: ",
155  config.GetDescriptor()->full_name());
156  }
157  auto* raw_result = new AlternateSubClass();
158  result->reset(raw_result);
159  raw_result->config_ = config;
160  return OkStatus();
161  }
162 
163  string class_name() const override { return "AlternateSubClass"; }
164 
165  string config_data() const override { return config_.string_field(); }
166 
167  private:
168  Config1 config_;
169 };
170 REGISTER_MY_BASE_CLASS_USING_ALTERNATE_REGISTRY(AlternateSubClass, Config1);
171 
172 TEST(ClassRegistrationTest, MultipleRegistriesForSameBaseClass) {
173  std::unique_ptr<MyBaseClass> loaded_subclass;
174 
175  Config1 config;
176  TF_ASSERT_OK(MyBaseClassRegistry::Create(config, &loaded_subclass));
177  EXPECT_EQ("SubClass1", loaded_subclass->class_name());
178 
179  TF_ASSERT_OK(AlternateMyBaseClassRegistry::Create(config, &loaded_subclass));
180  EXPECT_EQ("AlternateSubClass", loaded_subclass->class_name());
181 }
182 
183 // A base class whose subclasses' Create() methods take additional parameters,
184 // and associated registry.
185 class MyParameterizedBaseClass {
186  public:
187  virtual ~MyParameterizedBaseClass() = default;
188  virtual string class_name() const = 0;
189  virtual string config_data() const = 0;
190  virtual int param1_data() const = 0;
191  virtual string param2_data() const = 0;
192  virtual const std::map<string, int>& param3_data() const = 0;
193 };
194 DEFINE_CLASS_REGISTRY(MyParameterizedBaseClassRegistry,
195  MyParameterizedBaseClass, int, const string&,
196  const std::map<string TFS_COMMA int>&);
197 #define REGISTER_MY_PARAMETERIZED_BASE_CLASS(SubClassCreator, ConfigProto) \
198  REGISTER_CLASS(MyParameterizedBaseClassRegistry, MyParameterizedBaseClass, \
199  SubClassCreator, ConfigProto, int, const string&, \
200  const std::map<string TFS_COMMA int>&);
201 
202 // A subclass of MyParameterizedBaseClass that should be instantiated via
203 // Config1.
204 class ParameterizedSubClass1 : public MyParameterizedBaseClass {
205  public:
206  static Status Create(const Config1& config, int param1, const string& param2,
207  const std::map<string, int>& param3,
208  std::unique_ptr<MyParameterizedBaseClass>* result) {
209  if (config.GetDescriptor()->full_name() != "tensorflow.serving.Config1") {
210  return errors::InvalidArgument("Wrong type of config proto: ",
211  config.GetDescriptor()->full_name());
212  }
213  auto* raw_result = new ParameterizedSubClass1();
214  result->reset(raw_result);
215  raw_result->config_ = config;
216  raw_result->param1_ = param1;
217  raw_result->param2_ = param2;
218  raw_result->param3_ = param3;
219  return OkStatus();
220  }
221 
222  string class_name() const override { return "ParameterizedSubClass1"; }
223 
224  string config_data() const override { return config_.string_field(); }
225 
226  int param1_data() const override { return param1_; }
227 
228  string param2_data() const override { return param2_; }
229 
230  const std::map<string, int>& param3_data() const override { return param3_; }
231 
232  private:
233  Config1 config_;
234  int param1_;
235  string param2_;
236  std::map<string, int> param3_;
237 };
238 REGISTER_MY_PARAMETERIZED_BASE_CLASS(ParameterizedSubClass1, Config1);
239 
240 TEST(ClassRegistrationTest, InstantiateParameterizedFromRawConfig) {
241  std::unique_ptr<MyParameterizedBaseClass> loaded_subclass;
242 
243  Config1 config1;
244  config1.set_string_field("foo");
245  TF_ASSERT_OK(MyParameterizedBaseClassRegistry::Create(
246  config1, 42, "bar", {{"floop", 1}, {"mrop", 2}}, &loaded_subclass));
247  EXPECT_EQ("ParameterizedSubClass1", loaded_subclass->class_name());
248  EXPECT_EQ("foo", loaded_subclass->config_data());
249  EXPECT_EQ(42, loaded_subclass->param1_data());
250  EXPECT_EQ("bar", loaded_subclass->param2_data());
251  EXPECT_THAT(loaded_subclass->param3_data(),
252  UnorderedElementsAre(Pair("floop", 1), Pair("mrop", 2)));
253 }
254 
255 TEST(ClassRegistrationTest, InstantiateParameterizedFromAny) {
256  std::unique_ptr<MyParameterizedBaseClass> loaded_subclass;
257 
258  Config1 config1;
259  config1.set_string_field("foo");
260  google::protobuf::Any any_config1;
261  any_config1.PackFrom(config1);
262  TF_ASSERT_OK(MyParameterizedBaseClassRegistry::CreateFromAny(
263  any_config1, 42, "bar", {{"floop", 1}, {"mrop", 2}}, &loaded_subclass));
264  EXPECT_EQ("ParameterizedSubClass1", loaded_subclass->class_name());
265  EXPECT_EQ("foo", loaded_subclass->config_data());
266  EXPECT_EQ(42, loaded_subclass->param1_data());
267  EXPECT_EQ("bar", loaded_subclass->param2_data());
268  EXPECT_THAT(loaded_subclass->param3_data(),
269  UnorderedElementsAre(Pair("floop", 1), Pair("mrop", 2)));
270 }
271 
272 } // namespace
273 } // namespace serving
274 } // namespace tensorflow