TensorFlow Serving C++ API Documentation
class_registration.h
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A way to register subclasses of an abstract base class, and associate a
17 // config proto message type. Instances can be instantiated from an Any proto
18 // field that contains a config proto, based on the type and content of the
19 // config proto.
20 //
21 // IMPORTANT: Config protos used in registries must be compiled into the binary.
22 //
23 // Each registry has a name. Registry names in each namespace must be distinct.
24 // A registry is tied to a specific base class and factory signature. It is fine
25 // to have multiple registries for a given base class, whether having the same
26 // factory signature or multiple distinct signatures.
27 //
28 // Usage:
29 //
30 // // Define a base class.
31 // class MyBaseClass {
32 // ...
33 // };
34 //
35 // // Define a registry that maps from proto message types to subclasses of
36 // // MyBaseClass.
37 // DEFINE_CLASS_REGISTRY(MyBaseClassRegistry, MyBaseClass);
38 //
//   // Define a macro used to create a specific entry in MyBaseClassRegistry
//   // that maps from ConfigProto to ClassCreator::Create().
//   #define REGISTER_MY_BASE_CLASS(ClassCreator, ConfigProto)
//     REGISTER_CLASS(MyBaseClassRegistry, MyBaseClass, ClassCreator,
//                    ConfigProto);
44 //
//   // Declare a subclass of MyBaseClass to be created when a OneConfigProto
//   // is passed to the Create*() factory methods.
//   class OneClass : public MyBaseClass {
//    public:
//     static Status Create(const OneConfigProto& config,
//                          std::unique_ptr<MyBaseClass>* result) {
//       // Owned here so that the object is freed if Init() fails.
//       std::unique_ptr<OneClass> instance(new OneClass());
//       instance->config_ = config;
//       Status status = instance->Init();
//       if (status.ok()) {
//         *result = std::move(instance);
//       }
//       return status;
//     }
//
//    private:
//     Status Init() {
//       ... initialize the object based on 'config_'
//     }
//
//     OneConfigProto config_;
//   };
//   REGISTER_MY_BASE_CLASS(OneClass, OneConfigProto);
68 //
//   // Create an object of type OneClass using the registry to switch on
//   // the type OneConfigProto.
//   OneConfigProto config = ...
//   std::unique_ptr<MyBaseClass> loaded_subclass;
//   CHECK_OK(MyBaseClassRegistry::Create(config, &loaded_subclass));
//
//   // Same, but starting from an Any message that wraps a OneConfigProto.
//   protobuf::Any any_config = ...  // wraps a OneConfigProto
//   std::unique_ptr<MyBaseClass> loaded_subclass;
//   CHECK_OK(MyBaseClassRegistry::CreateFromAny(any_config, &loaded_subclass));
79 //
// Note that the subclass creator need not be the subclass itself. For example:
//
//   class AnotherClass : public MyBaseClass {
//    public:
//     AnotherClass(int a, int b);
//     ...
//   };
//
//   class CreatorForAnotherClass {
//    public:
//     static Status Create(const AnotherConfigProto& config,
//                          std::unique_ptr<MyBaseClass>* result) {
//       result->reset(new AnotherClass(config.a(), config.b()));
//       return Status::OK();
//     }
//   };
//
//   REGISTER_MY_BASE_CLASS(CreatorForAnotherClass, AnotherConfigProto);
98 //
99 //
// This mechanism also allows additional parameter passing into the Create()
// factory method. Consider the following example in which Create() takes an
// int, a string and an int->string map, in addition to the config proto:
//
//   class MyParameterizedBaseClass {
//     ...
//   };
//   DEFINE_CLASS_REGISTRY(MyParameterizedBaseClassRegistry,
//                         MyParameterizedBaseClass, int, const string&,
//                         const std::map<int TFS_COMMA string>&);
//   #define REGISTER_MY_PARAMETERIZED_BASE_CLASS(ClassCreator, ConfigProto)
//     REGISTER_CLASS(MyParameterizedBaseClassRegistry,
//                    MyParameterizedBaseClass, ClassCreator,
//                    ConfigProto, int, const string&,
//                    const std::map<int TFS_COMMA string>&);
//
//   class OneClass : public MyParameterizedBaseClass {
//    public:
//     static Status Create(const OneConfigProto& config,
//                          int param1, const string& param2,
//                          const std::map<int, string>& param3,
//                          std::unique_ptr<MyParameterizedBaseClass>* result) {
//       ...
//     }
//     ...
//   };
//   REGISTER_MY_PARAMETERIZED_BASE_CLASS(OneClass, OneConfigProto);
//
//   OneConfigProto config = ...
//   int int_param = ...
//   string string_param = ...
//   std::map<int, string> map_param = ...
//   std::unique_ptr<MyParameterizedBaseClass> loaded_subclass;
//   CHECK_OK(MyParameterizedBaseClassRegistry::Create(config,
//                                                     int_param,
//                                                     string_param,
//                                                     map_param,
//                                                     &loaded_subclass));
136 //
137 // The registry name can be anything you choose, and it's fine to have multiple
138 // registries for the same base class, potentially with different factory
139 // signatures.
140 //
// Note that types containing a comma (e.g. std::map<string, int>) must use
// TFS_COMMA in place of ',' when passed to these macros.
// TODO(b/24472377): Eliminate the requirement to use TFS_COMMA.
144 
145 #ifndef TENSORFLOW_SERVING_UTIL_CLASS_REGISTRATION_H_
146 #define TENSORFLOW_SERVING_UTIL_CLASS_REGISTRATION_H_
147 
148 #include <algorithm>
149 #include <memory>
150 #include <string>
151 #include <unordered_map>
152 
153 #include "google/protobuf/any.pb.h"
154 #include "google/protobuf/descriptor.h"
155 #include "google/protobuf/message.h"
156 #include "tensorflow/core/lib/core/errors.h"
157 #include "tensorflow/core/lib/core/status.h"
158 #include "tensorflow/core/lib/core/stringpiece.h"
159 #include "tensorflow/core/lib/strings/strcat.h"
160 #include "tensorflow/core/platform/macros.h"
161 #include "tensorflow/core/platform/mutex.h"
162 #include "tensorflow/core/platform/protobuf.h"
163 #include "tensorflow/core/platform/thread_annotations.h"
164 #include "tensorflow_serving/util/class_registration_util.h"
165 
166 namespace tensorflow {
167 namespace serving {
168 namespace internal {
169 
170 // The interface for a factory method that takes a protobuf::Message as
171 // input to construct an object of type BaseClass.
172 template <typename BaseClass, typename... AdditionalFactoryArgs>
174  public:
175  virtual ~AbstractClassRegistrationFactory() = default;
176 
177  // Creates an object using this factory. Fails if 'config' is not of the
178  // expected type.
179  virtual Status Create(const protobuf::Message& config,
180  AdditionalFactoryArgs... args,
181  std::unique_ptr<BaseClass>* result) const = 0;
182 };
183 
184 // The interface for a factory method that takes a protobuf::Message as
185 // input to construct an object of type BaseClass.
186 template <typename BaseClass, typename Class, typename Config,
187  typename... AdditionalFactoryArgs>
189  : public AbstractClassRegistrationFactory<BaseClass,
190  AdditionalFactoryArgs...> {
191  public:
193  : config_descriptor_(Config::default_instance().GetDescriptor()) {}
194 
195  // Creates an object using this factory. Fails if 'config' is not of the
196  // expected type.
197  Status Create(const protobuf::Message& config, AdditionalFactoryArgs... args,
198  std::unique_ptr<BaseClass>* result) const override {
199  if (config.GetDescriptor()->full_name() !=
200  config_descriptor_->full_name()) {
201  return errors::InvalidArgument(
202  "Supplied config proto of type ", config.GetDescriptor()->full_name(),
203  " does not match expected type ", config_descriptor_->full_name());
204  }
205  return Class::Create(static_cast<const Config&>(config),
206  std::forward<AdditionalFactoryArgs>(args)..., result);
207  }
208 
209  private:
210  const protobuf::Descriptor* const config_descriptor_;
211 
212  TF_DISALLOW_COPY_AND_ASSIGN(ClassRegistrationFactory);
213 };
214 
// Type-URL prefix conventionally used by google.protobuf.Any type URLs
// ("type.googleapis.com/<fully.qualified.MessageName>"). Not referenced
// elsewhere in this header; presumably kept for the Any-URL helpers — verify.
// 'static' only spells out the internal linkage a namespace-scope constexpr
// already has.
static constexpr char kTypeGoogleApisComPrefix[] = "type.googleapis.com/";
216 
217 // A static map whose keys are proto message names, and values are
218 // ClassRegistrationFactories. Includes a Create() factory method that
219 // performs a lookup in the map and calls Create() on the
220 // ClassRegistrationFactory it finds.
221 template <typename RegistryName, typename BaseClass,
222  typename... AdditionalFactoryArgs>
224  public:
225  using FactoryType =
226  AbstractClassRegistrationFactory<BaseClass, AdditionalFactoryArgs...>;
227 
228  // Creates an instance of BaseClass based on a config proto.
229  static Status Create(const protobuf::Message& config,
230  AdditionalFactoryArgs... args,
231  std::unique_ptr<BaseClass>* result) {
232  const string& config_proto_message_type =
233  config.GetDescriptor()->full_name();
234  auto* factory = LookupFromMap(config_proto_message_type);
235  if (factory == nullptr) {
236  return errors::InvalidArgument(
237  "Couldn't find factory for config proto message type ",
238  config_proto_message_type, "\nconfig=", config.DebugString());
239  }
240  return factory->Create(config, std::forward<AdditionalFactoryArgs>(args)...,
241  result);
242  }
243 
244  // Creates an instance of BaseClass based on a config proto embedded in an Any
245  // message.
246  //
247  // Requires that the config proto in the Any has a compiled-in descriptor.
248  static Status CreateFromAny(const google::protobuf::Any& any_config,
249  AdditionalFactoryArgs... args,
250  std::unique_ptr<BaseClass>* result) {
251  // Copy the config to a proto message of the indicated type.
252  string full_type_name;
253  Status parse_status =
254  ParseUrlForAnyType(any_config.type_url(), &full_type_name);
255  if (!parse_status.ok()) {
256  return parse_status;
257  }
258  const protobuf::Descriptor* descriptor =
259  protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
260  full_type_name);
261  if (descriptor == nullptr) {
262  return errors::Internal(
263  "Unable to find compiled-in proto descriptor of type ",
264  full_type_name);
265  }
266  std::unique_ptr<protobuf::Message> config(
267  protobuf::MessageFactory::generated_factory()
268  ->GetPrototype(descriptor)
269  ->New());
270  if (!any_config.UnpackTo(config.get())) {
271  return errors::InvalidArgument("Malformed content of Any: ",
272  any_config.DebugString());
273  }
274  return Create(*config, std::forward<AdditionalFactoryArgs>(args)...,
275  result);
276  }
277 
278  // Nested class whose instantiation inserts a key/value pair into the factory
279  // map.
280  class MapInserter {
281  public:
282  MapInserter(const string& config_proto_message_type, FactoryType* factory) {
283  InsertIntoMap(config_proto_message_type, factory);
284  }
285  };
286 
287  private:
288  // Inserts a key/value pair into the factory map.
289  static void InsertIntoMap(const string& config_proto_message_type,
290  FactoryType* factory) {
291  LockableFactoryMap* global_map = GlobalFactoryMap();
292  {
293  mutex_lock lock(global_map->mu);
294  global_map->factory_map.insert({config_proto_message_type, factory});
295  }
296  }
297 
298  // Retrieves a value from the factory map, or returns nullptr if no value was
299  // found.
300  static FactoryType* LookupFromMap(const string& config_proto_message_type) {
301  LockableFactoryMap* global_map = GlobalFactoryMap();
302  {
303  mutex_lock lock(global_map->mu);
304  auto it = global_map->factory_map.find(config_proto_message_type);
305  if (it == global_map->factory_map.end()) {
306  return nullptr;
307  }
308  return it->second;
309  }
310  }
311 
312  // A map from proto descriptor names to factories, with a lock.
313  struct LockableFactoryMap {
314  mutex mu;
315  std::unordered_map<string, FactoryType*> factory_map TF_GUARDED_BY(mu);
316  };
317 
318  // Returns a pointer to the factory map. There is one factory map per set of
319  // template parameters of this class.
320  static LockableFactoryMap* GlobalFactoryMap() {
321  static auto* global_map = []() -> LockableFactoryMap* {
322  return new LockableFactoryMap;
323  }();
324  return global_map;
325  }
326 
327  TF_DISALLOW_COPY_AND_ASSIGN(ClassRegistry);
328 };
329 
330 } // namespace internal
331 
// Stand-in for ',' inside macro arguments. The preprocessor splits macro
// arguments on every top-level comma, so a type such as map<string, int>
// would otherwise be broken into two arguments by the macros below.
// TODO(b/24472377): Eliminate the requirement to use TFS_COMMA, via some fancy
// macro gymnastics.
#define TFS_COMMA ,
337 
// Defines a class named by the first argument that serves as a registry of
// factories for subclasses of the second argument. Any further arguments are
// the types of the additional parameters that the registry's Create() and
// CreateFromAny() accept between the config and the result pointer.
#define DEFINE_CLASS_REGISTRY(registry_name, base_class, ...)  \
  class registry_name                                          \
      : public ::tensorflow::serving::internal::ClassRegistry< \
            registry_name, base_class, ##__VA_ARGS__> {};
342 
// Registers class_creator::Create() as the factory that registry_name uses
// for config messages of type config_proto. Any trailing arguments must be
// the same additional factory argument types that were given to
// DEFINE_CLASS_REGISTRY for that registry: together with base_class they
// select the ClassRegistry instantiation, and hence the factory map, that
// receives the registration.
#define REGISTER_CLASS(registry_name, base_class, class_creator, config_proto, \
                       ...)                                                    \
  REGISTER_CLASS_UNIQ_HELPER(__COUNTER__, registry_name, base_class,           \
                             class_creator, config_proto, ##__VA_ARGS__)

// Extra expansion step: 'counter' is not an operand of ## here, so
// __COUNTER__ is replaced by its numeric value before REGISTER_CLASS_UNIQ
// pastes it into an identifier.
#define REGISTER_CLASS_UNIQ_HELPER(counter, registry_name, base_class,         \
                                   class_creator, config_proto, ...)           \
  REGISTER_CLASS_UNIQ(counter, registry_name, base_class, class_creator,       \
                      config_proto, ##__VA_ARGS__)

// Defines a static ClassRegistry<...>::MapInserter named register_class_<N>
// whose constructor inserts a newly allocated ClassRegistrationFactory into
// the registry's factory map under config_proto's full message name. Nothing
// in this header deletes that factory.
#define REGISTER_CLASS_UNIQ(counter, registry_name, base_class, class_creator, \
                            config_proto, ...)                                 \
  static ::tensorflow::serving::internal::ClassRegistry<                       \
      registry_name, base_class, ##__VA_ARGS__>::MapInserter                   \
      register_class_##counter(                                                \
          (config_proto::default_instance().GetDescriptor()->full_name()),     \
          (new ::tensorflow::serving::internal::ClassRegistrationFactory<      \
              base_class, class_creator, config_proto, ##__VA_ARGS__>));
363 
364 } // namespace serving
365 } // namespace tensorflow
366 
367 #endif // TENSORFLOW_SERVING_UTIL_CLASS_REGISTRATION_H_