16 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/platform/types.h"
22 #include "tsl/platform/errors.h"
23 #include "tensorflow_serving/core/simple_loader.h"
24 #include "tensorflow_serving/resources/resource_util.h"
25 #include "tensorflow_serving/resources/resource_values.h"
26 #include "tensorflow_serving/resources/resources.pb.h"
27 #include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
28 #include "tensorflow_serving/servables/tensorflow/file_acl.h"
29 #include "tensorflow_serving/servables/tensorflow/machine_learning_metadata.h"
30 #include "tensorflow_serving/servables/tensorflow/servable.h"
31 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_factory.h"
32 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
34 namespace tensorflow {
// Creates a TfrtSavedModelSourceAdapter that takes ownership of a
// TfrtSavedModelFactory and stores the new adapter in `*adapter`.
//
// Returns absl::OkStatus() on the path shown here.
//
// NOTE(review): in the lines visible in this excerpt, `factory` is declared
// (default-constructed, i.e. null) and then moved into the adapter without
// ever being populated from `config`. The registration entry point further
// down builds its factory with
// TfrtSavedModelFactory::Create(config.saved_model_config(), &factory);
// presumably the same call, wrapped in TF_RETURN_IF_ERROR, belongs here.
// Confirm against the full file. The function's closing brace is also not
// visible in this excerpt.
37 Status TfrtSavedModelSourceAdapter::Create(
38 const TfrtSavedModelSourceAdapterConfig& config,
39 std::unique_ptr<TfrtSavedModelSourceAdapter>* adapter) {
40 std::unique_ptr<TfrtSavedModelFactory> factory;
// Transfer ownership of the factory into the newly constructed adapter and
// hand the adapter to the caller through `*adapter`.
43 adapter->reset(
new TfrtSavedModelSourceAdapter(std::move(factory)));
44 return absl::OkStatus();
47 TfrtSavedModelSourceAdapter::~TfrtSavedModelSourceAdapter() { Detach(); }
49 TfrtSavedModelSourceAdapter::TfrtSavedModelSourceAdapter(
50 std::unique_ptr<TfrtSavedModelFactory> factory)
51 : factory_(std::move(factory)) {}
// Returns the servable-creator callback handed to SimpleLoader<Servable>.
//
// When invoked with the loader's metadata, the callback:
//   1. calls RegisterModelRoot(metadata.servable_id, path), and
//   2. calls factory->CreateTfrtSavedModelWithMetadata(metadata, path,
//      servable) to build the servable into `*servable`.
// An error status from either step is returned to the caller. Otherwise the
// callback returns absl::OkStatus().
//
// The lambda captures only `factory` (a shared_ptr copy) and `path`, both by
// value. It does not capture `this`, so it does not depend on the adapter
// outliving it.
53 SimpleLoader<Servable>::CreatorVariant
54 TfrtSavedModelSourceAdapter::GetServableCreator(
55 std::shared_ptr<TfrtSavedModelFactory> factory,
56 const StoragePath& path)
const {
57 return [factory, path](
const Loader::Metadata& metadata,
58 std::unique_ptr<Servable>* servable) {
// Register the model root for this servable id before creating it.
59 TF_RETURN_IF_ERROR(RegisterModelRoot(metadata.servable_id, path));
// NOTE(review): the opening `TF_RETURN_IF_ERROR(` of the next statement, and
// the closing of the lambda and of this function, are not visible in this
// excerpt.
61 factory->CreateTfrtSavedModelWithMetadata(metadata, path, servable));
62 return absl::OkStatus();
// Converts a storage path into a Loader for the saved model at `path`.
//
// The loader stored in `*loader` is a SimpleLoader<Servable> built from three
// callbacks. Each captures a shared_ptr copy of `factory_` and `path` by
// value:
//   * servable_creator: obtained from GetServableCreator(); builds the servable.
//   * resource_estimator: asks the factory for its resource estimate, then
//     rewrites the main-device RAM quantity in that estimate. Part of that
//     adjustment is on lines not visible in this excerpt.
//   * post_load_resource_estimator: returns the factory's estimate as-is.
// Returns absl::OkStatus() on the path shown.
66 Status TfrtSavedModelSourceAdapter::Convert(
const StoragePath& path,
67 std::unique_ptr<Loader>* loader) {
// Share ownership of the factory with the callbacks created below.
68 std::shared_ptr<TfrtSavedModelFactory> factory = factory_;
69 auto servable_creator = GetServableCreator(factory, path);
70 auto resource_estimator = [factory, path](ResourceAllocation* estimate) {
// Start from the factory's own estimate; propagate any error.
71 TF_RETURN_IF_ERROR(factory->EstimateResourceRequirement(path, estimate));
// Local ResourceUtil configured with device_types::kMain (count 1,
// presumably one device instance). It is used only to locate the RAM
// resource inside `*estimate`.
// NOTE(review): this could be a plain local object or std::make_unique,
// rather than unique_ptr(new ...), since it never escapes the lambda.
73 ResourceUtil::Options resource_util_options;
74 resource_util_options.devices = {{device_types::kMain, 1}};
75 std::unique_ptr<ResourceUtil> resource_util =
76 std::unique_ptr<ResourceUtil>(
new ResourceUtil(resource_util_options));
77 const Resource ram_resource = resource_util->CreateBoundResource(
78 device_types::kMain, resource_kinds::kRamBytes);
// Rewrite the RAM quantity in `*estimate`, starting from its current value.
79 resource_util->SetQuantity(
80 ram_resource, resource_util->GetQuantity(ram_resource, *estimate),
// NOTE(review): the remaining arguments of SetQuantity(), the statements
// between it and the return, and the lambda's closing `};` are not visible in
// this excerpt.
83 return absl::OkStatus();
// Post-load estimate: the factory's estimate, without the RAM rewrite that
// the pre-load estimator performs above.
85 auto post_load_resource_estimator = [factory,
86 path](ResourceAllocation* estimate) {
87 return factory->EstimateResourceRequirement(path, estimate);
// Hand ownership of the newly built loader to the caller via `*loader`.
89 loader->reset(
new SimpleLoader<Servable>(servable_creator, resource_estimator,
90 {post_load_resource_estimator}));
91 return absl::OkStatus();
// NOTE(review): fragment of the source-adapter registration entry point.
// Several parts are not visible in this excerpt: the start of its signature,
// the line closing the parameter list, what presumably wraps the factory
// creation call (the extra `)` suggests TF_RETURN_IF_ERROR), and any code
// that stores an adapter into the output parameter. As shown, it creates a
// TfrtSavedModelFactory from config.saved_model_config() and returns
// absl::OkStatus().
98 const TfrtSavedModelSourceAdapterConfig& config,
99 std::unique_ptr<
SourceAdapter<StoragePath, std::unique_ptr<Loader>>>*
101 std::unique_ptr<TfrtSavedModelFactory> factory;
103 TfrtSavedModelFactory::Create(config.saved_model_config(), &factory));
105 return absl::OkStatus();
// NOTE(review): presumably the tail of the macro invocation that registers
// this creator for TfrtSavedModelSourceAdapterConfig; its head is not visible
// in this excerpt.
109 TfrtSavedModelSourceAdapterConfig);
// NOTE(review): stray signature, presumably of TfrtSavedModelFactory::Create.
// Its parameter types match the call
// TfrtSavedModelFactory::Create(config.saved_model_config(), &factory) used
// in the registration entry point. As written, it has neither a body nor a
// terminating `;`, so it is likely an extraction artifact. Confirm, then
// remove it or move it to the factory's header.
static absl::Status Create(const TfrtSavedModelConfig &config, std::unique_ptr< TfrtSavedModelFactory > *factory)