#include "tensorflow_serving/util/class_registration.h"

#include <map>
#include <memory>
#include <utility>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow_serving/util/class_registration_test.pb.h"

using ::testing::Pair;
using ::testing::UnorderedElementsAre;
29 namespace tensorflow {
36 virtual ~MyBaseClass() =
default;
37 virtual string class_name()
const = 0;
38 virtual string config_data()
const = 0;
40 DEFINE_CLASS_REGISTRY(MyBaseClassRegistry, MyBaseClass);
41 #define REGISTER_MY_BASE_CLASS(SubClassCreator, ConfigProto) \
42 REGISTER_CLASS(MyBaseClassRegistry, MyBaseClass, SubClassCreator, \
46 class SubClass1 :
public MyBaseClass {
48 static Status Create(
const Config1& config,
49 std::unique_ptr<MyBaseClass>* result) {
50 if (config.GetDescriptor()->full_name() !=
"tensorflow.serving.Config1") {
51 return errors::InvalidArgument(
"Wrong type of config proto: ",
52 config.GetDescriptor()->full_name());
54 auto* raw_result =
new SubClass1();
55 result->reset(raw_result);
56 raw_result->config_ = config;
60 string class_name()
const override {
return "SubClass1"; }
62 string config_data()
const override {
return config_.string_field(); }
67 REGISTER_MY_BASE_CLASS(SubClass1, Config1);
70 class SubClass2 :
public MyBaseClass {
72 static Status Create(
const Config2& config,
73 std::unique_ptr<MyBaseClass>* result) {
74 if (config.GetDescriptor()->full_name() !=
"tensorflow.serving.Config2") {
75 return errors::InvalidArgument(
"Wrong type of config proto: ",
76 config.GetDescriptor()->full_name());
78 auto* raw_result =
new SubClass2();
79 result->reset(raw_result);
80 raw_result->config_ = config;
84 string class_name()
const override {
return "SubClass2"; }
86 string config_data()
const override {
return config_.string_field(); }
93 class SubClass2Creator {
95 static Status Create(
const Config2& config,
96 std::unique_ptr<MyBaseClass>* result) {
97 return SubClass2::Create(config, result);
100 REGISTER_MY_BASE_CLASS(SubClass2Creator, Config2);
102 TEST(ClassRegistrationTest, InstantiateFromRawConfig) {
103 std::unique_ptr<MyBaseClass> loaded_subclass;
106 config1.set_string_field(
"foo");
107 TF_ASSERT_OK(MyBaseClassRegistry::Create(config1, &loaded_subclass));
108 EXPECT_EQ(
"SubClass1", loaded_subclass->class_name());
109 EXPECT_EQ(
"foo", loaded_subclass->config_data());
112 config2.set_string_field(
"bar");
113 TF_ASSERT_OK(MyBaseClassRegistry::Create(config2, &loaded_subclass));
114 EXPECT_EQ(
"SubClass2", loaded_subclass->class_name());
115 EXPECT_EQ(
"bar", loaded_subclass->config_data());
118 TEST(ClassRegistrationTest, InstantiateFromAny) {
119 std::unique_ptr<MyBaseClass> loaded_subclass;
122 config1.set_string_field(
"foo");
123 google::protobuf::Any any_config1;
124 any_config1.PackFrom(config1);
126 MyBaseClassRegistry::CreateFromAny(any_config1, &loaded_subclass));
127 EXPECT_EQ(
"SubClass1", loaded_subclass->class_name());
128 EXPECT_EQ(
"foo", loaded_subclass->config_data());
131 config2.set_string_field(
"bar");
132 google::protobuf::Any any_config2;
133 any_config2.PackFrom(config2);
135 MyBaseClassRegistry::CreateFromAny(any_config2, &loaded_subclass));
136 EXPECT_EQ(
"SubClass2", loaded_subclass->class_name());
137 EXPECT_EQ(
"bar", loaded_subclass->config_data());
141 DEFINE_CLASS_REGISTRY(AlternateMyBaseClassRegistry, MyBaseClass);
142 #define REGISTER_MY_BASE_CLASS_USING_ALTERNATE_REGISTRY(SubClassCreator, \
144 REGISTER_CLASS(AlternateMyBaseClassRegistry, MyBaseClass, SubClassCreator, \
149 class AlternateSubClass :
public MyBaseClass {
151 static Status Create(
const Config1& config,
152 std::unique_ptr<MyBaseClass>* result) {
153 if (config.GetDescriptor()->full_name() !=
"tensorflow.serving.Config1") {
154 return errors::InvalidArgument(
"Wrong type of config proto: ",
155 config.GetDescriptor()->full_name());
157 auto* raw_result =
new AlternateSubClass();
158 result->reset(raw_result);
159 raw_result->config_ = config;
163 string class_name()
const override {
return "AlternateSubClass"; }
165 string config_data()
const override {
return config_.string_field(); }
170 REGISTER_MY_BASE_CLASS_USING_ALTERNATE_REGISTRY(AlternateSubClass, Config1);
172 TEST(ClassRegistrationTest, MultipleRegistriesForSameBaseClass) {
173 std::unique_ptr<MyBaseClass> loaded_subclass;
176 TF_ASSERT_OK(MyBaseClassRegistry::Create(config, &loaded_subclass));
177 EXPECT_EQ(
"SubClass1", loaded_subclass->class_name());
179 TF_ASSERT_OK(AlternateMyBaseClassRegistry::Create(config, &loaded_subclass));
180 EXPECT_EQ(
"AlternateSubClass", loaded_subclass->class_name());
185 class MyParameterizedBaseClass {
187 virtual ~MyParameterizedBaseClass() =
default;
188 virtual string class_name()
const = 0;
189 virtual string config_data()
const = 0;
190 virtual int param1_data()
const = 0;
191 virtual string param2_data()
const = 0;
192 virtual const std::map<string, int>& param3_data()
const = 0;
194 DEFINE_CLASS_REGISTRY(MyParameterizedBaseClassRegistry,
195 MyParameterizedBaseClass,
int,
const string&,
196 const std::map<string TFS_COMMA int>&);
197 #define REGISTER_MY_PARAMETERIZED_BASE_CLASS(SubClassCreator, ConfigProto) \
198 REGISTER_CLASS(MyParameterizedBaseClassRegistry, MyParameterizedBaseClass, \
199 SubClassCreator, ConfigProto, int, const string&, \
200 const std::map<string TFS_COMMA int>&);
204 class ParameterizedSubClass1 :
public MyParameterizedBaseClass {
206 static Status Create(
const Config1& config,
int param1,
const string& param2,
207 const std::map<string, int>& param3,
208 std::unique_ptr<MyParameterizedBaseClass>* result) {
209 if (config.GetDescriptor()->full_name() !=
"tensorflow.serving.Config1") {
210 return errors::InvalidArgument(
"Wrong type of config proto: ",
211 config.GetDescriptor()->full_name());
213 auto* raw_result =
new ParameterizedSubClass1();
214 result->reset(raw_result);
215 raw_result->config_ = config;
216 raw_result->param1_ = param1;
217 raw_result->param2_ = param2;
218 raw_result->param3_ = param3;
222 string class_name()
const override {
return "ParameterizedSubClass1"; }
224 string config_data()
const override {
return config_.string_field(); }
226 int param1_data()
const override {
return param1_; }
228 string param2_data()
const override {
return param2_; }
230 const std::map<string, int>& param3_data()
const override {
return param3_; }
236 std::map<string, int> param3_;
238 REGISTER_MY_PARAMETERIZED_BASE_CLASS(ParameterizedSubClass1, Config1);
240 TEST(ClassRegistrationTest, InstantiateParameterizedFromRawConfig) {
241 std::unique_ptr<MyParameterizedBaseClass> loaded_subclass;
244 config1.set_string_field(
"foo");
245 TF_ASSERT_OK(MyParameterizedBaseClassRegistry::Create(
246 config1, 42,
"bar", {{
"floop", 1}, {
"mrop", 2}}, &loaded_subclass));
247 EXPECT_EQ(
"ParameterizedSubClass1", loaded_subclass->class_name());
248 EXPECT_EQ(
"foo", loaded_subclass->config_data());
249 EXPECT_EQ(42, loaded_subclass->param1_data());
250 EXPECT_EQ(
"bar", loaded_subclass->param2_data());
251 EXPECT_THAT(loaded_subclass->param3_data(),
252 UnorderedElementsAre(Pair(
"floop", 1), Pair(
"mrop", 2)));
// Packs a Config1 into google::protobuf::Any, creates an instance through
// MyParameterizedBaseClassRegistry::CreateFromAny with extra arguments
// (42, "bar", {{"floop", 1}, {"mrop", 2}}), and checks that the registry
// chose ParameterizedSubClass1 and that the config and every extra argument
// reached it.
// NOTE(review): the leading numbers on these lines (e.g. "255") look like
// line-number residue and are not valid C++ -- TODO strip them. The test's
// closing brace is not visible in this view.
255 TEST(ClassRegistrationTest, InstantiateParameterizedFromAny) {
256 std::unique_ptr<MyParameterizedBaseClass> loaded_subclass;
// NOTE(review): config1 is used below but no declaration is visible here;
// presumably a "Config1 config1;" line was lost -- TODO restore it.
259 config1.set_string_field(
"foo");
// Wrap the config in an Any so the registry has to resolve the packed type.
260 google::protobuf::Any any_config1;
261 any_config1.PackFrom(config1);
// Create from the Any, passing the three extra arguments after the config.
262 TF_ASSERT_OK(MyParameterizedBaseClassRegistry::CreateFromAny(
263 any_config1, 42,
"bar", {{
"floop", 1}, {
"mrop", 2}}, &loaded_subclass));
// The instance must be ParameterizedSubClass1 and carry the config data plus
// each extra argument exactly as passed.
264 EXPECT_EQ(
"ParameterizedSubClass1", loaded_subclass->class_name());
265 EXPECT_EQ(
"foo", loaded_subclass->config_data());
266 EXPECT_EQ(42, loaded_subclass->param1_data());
267 EXPECT_EQ(
"bar", loaded_subclass->param2_data());
268 EXPECT_THAT(loaded_subclass->param3_data(),
269 UnorderedElementsAre(Pair(
"floop", 1), Pair(
"mrop", 2)));