145 #ifndef TENSORFLOW_SERVING_UTIL_CLASS_REGISTRATION_H_
146 #define TENSORFLOW_SERVING_UTIL_CLASS_REGISTRATION_H_
151 #include <unordered_map>
153 #include "google/protobuf/any.pb.h"
154 #include "google/protobuf/descriptor.h"
155 #include "google/protobuf/message.h"
156 #include "tensorflow/core/lib/core/errors.h"
157 #include "tensorflow/core/lib/core/status.h"
158 #include "tensorflow/core/lib/core/stringpiece.h"
159 #include "tensorflow/core/lib/strings/strcat.h"
160 #include "tensorflow/core/platform/macros.h"
161 #include "tensorflow/core/platform/mutex.h"
162 #include "tensorflow/core/platform/protobuf.h"
163 #include "tensorflow/core/platform/thread_annotations.h"
164 #include "tensorflow_serving/util/class_registration_util.h"
166 namespace tensorflow {
172 template <
typename BaseClass,
typename... AdditionalFactoryArgs>
179 virtual Status Create(
const protobuf::Message& config,
180 AdditionalFactoryArgs... args,
181 std::unique_ptr<BaseClass>* result)
const = 0;
186 template <
typename BaseClass,
typename Class,
typename Config,
187 typename... AdditionalFactoryArgs>
190 AdditionalFactoryArgs...> {
193 : config_descriptor_(Config::default_instance().GetDescriptor()) {}
197 Status Create(
const protobuf::Message& config, AdditionalFactoryArgs... args,
198 std::unique_ptr<BaseClass>* result)
const override {
199 if (config.GetDescriptor()->full_name() !=
200 config_descriptor_->full_name()) {
201 return errors::InvalidArgument(
202 "Supplied config proto of type ", config.GetDescriptor()->full_name(),
203 " does not match expected type ", config_descriptor_->full_name());
205 return Class::Create(
static_cast<const Config&
>(config),
206 std::forward<AdditionalFactoryArgs>(args)..., result);
210 const protobuf::Descriptor*
const config_descriptor_;
// Literal prefix "type.googleapis.com/".
// NOTE(review): presumably the scheme/host portion of google.protobuf.Any
// type URLs that the Any-parsing path (ParseUrlForAnyType) strips off to
// obtain a full proto message name; the uses of this constant are not
// visible in this chunk, so confirm before relying on that description.
constexpr char kTypeGoogleApisComPrefix[] = "type.googleapis.com/";
221 template <
typename RegistryName,
typename BaseClass,
222 typename... AdditionalFactoryArgs>
229 static Status Create(
const protobuf::Message& config,
230 AdditionalFactoryArgs... args,
231 std::unique_ptr<BaseClass>* result) {
232 const string& config_proto_message_type =
233 config.GetDescriptor()->full_name();
234 auto* factory = LookupFromMap(config_proto_message_type);
235 if (factory ==
nullptr) {
236 return errors::InvalidArgument(
237 "Couldn't find factory for config proto message type ",
238 config_proto_message_type,
"\nconfig=", config.DebugString());
240 return factory->Create(config, std::forward<AdditionalFactoryArgs>(args)...,
248 static Status CreateFromAny(
const google::protobuf::Any& any_config,
249 AdditionalFactoryArgs... args,
250 std::unique_ptr<BaseClass>* result) {
252 string full_type_name;
253 Status parse_status =
254 ParseUrlForAnyType(any_config.type_url(), &full_type_name);
255 if (!parse_status.ok()) {
258 const protobuf::Descriptor* descriptor =
259 protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
261 if (descriptor ==
nullptr) {
262 return errors::Internal(
263 "Unable to find compiled-in proto descriptor of type ",
266 std::unique_ptr<protobuf::Message> config(
267 protobuf::MessageFactory::generated_factory()
268 ->GetPrototype(descriptor)
270 if (!any_config.UnpackTo(config.get())) {
271 return errors::InvalidArgument(
"Malformed content of Any: ",
272 any_config.DebugString());
274 return Create(*config, std::forward<AdditionalFactoryArgs>(args)...,
283 InsertIntoMap(config_proto_message_type, factory);
289 static void InsertIntoMap(
const string& config_proto_message_type,
291 LockableFactoryMap* global_map = GlobalFactoryMap();
293 mutex_lock lock(global_map->mu);
294 global_map->factory_map.insert({config_proto_message_type, factory});
300 static FactoryType* LookupFromMap(
const string& config_proto_message_type) {
301 LockableFactoryMap* global_map = GlobalFactoryMap();
303 mutex_lock lock(global_map->mu);
304 auto it = global_map->factory_map.find(config_proto_message_type);
305 if (it == global_map->factory_map.end()) {
313 struct LockableFactoryMap {
315 std::unordered_map<string, FactoryType*> factory_map TF_GUARDED_BY(mu);
320 static LockableFactoryMap* GlobalFactoryMap() {
321 static auto* global_map = []() -> LockableFactoryMap* {
322 return new LockableFactoryMap;
327 TF_DISALLOW_COPY_AND_ASSIGN(ClassRegistry);
// Declares a registry class named `RegistryName` whose factories produce
// `BaseClass` instances. The generated class adds no members of its own; it
// publicly derives from
//   ::tensorflow::serving::internal::ClassRegistry<RegistryName, BaseClass,
//                                                  extra args...>
// Passing `RegistryName` itself as the first template argument gives every
// registry a distinct ClassRegistry instantiation (and therefore, presumably,
// its own static factory map). Any macro arguments after `BaseClass` become
// the registry's additional factory argument types; the GNU `, ##__VA_ARGS__`
// form drops the trailing comma when none are supplied.
#define DEFINE_CLASS_REGISTRY(RegistryName, BaseClass, ...)   \
  class RegistryName                                           \
      : public ::tensorflow::serving::internal::ClassRegistry< \
            RegistryName, BaseClass, ##__VA_ARGS__> {};
345 #define REGISTER_CLASS(RegistryName, BaseClass, ClassCreator, config_proto, \
347 REGISTER_CLASS_UNIQ_HELPER(__COUNTER__, RegistryName, BaseClass, \
348 ClassCreator, config_proto, ##__VA_ARGS__)
350 #define REGISTER_CLASS_UNIQ_HELPER(cnt, RegistryName, BaseClass, ClassCreator, \
352 REGISTER_CLASS_UNIQ(cnt, RegistryName, BaseClass, ClassCreator, \
353 config_proto, ##__VA_ARGS__)
355 #define REGISTER_CLASS_UNIQ(cnt, RegistryName, BaseClass, ClassCreator, \
357 static ::tensorflow::serving::internal::ClassRegistry< \
358 RegistryName, BaseClass, ##__VA_ARGS__>::MapInserter \
359 register_class_##cnt( \
360 (config_proto::default_instance().GetDescriptor()->full_name()), \
361 (new ::tensorflow::serving::internal::ClassRegistrationFactory< \
362 BaseClass, ClassCreator, config_proto, ##__VA_ARGS__>));