16 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_factory.h"
19 #include <unordered_set>
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/cc/saved_model/tag_constants.h"
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/protobuf/config.pb.h"
29 #include "tensorflow/core/protobuf/meta_graph.pb.h"
30 #include "tensorflow/core/protobuf/named_tensor.pb.h"
31 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
32 #include "tensorflow/core/public/session_options.h"
33 #include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
34 #include "tensorflow_serving/servables/tensorflow/tflite_session.h"
35 #include "tensorflow_serving/session_bundle/session_bundle_util.h"
namespace tensorflow {
namespace serving {

namespace {
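
// Extracts all SignatureDefs from the bundle's MetaGraphDef.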
std::vector<SignatureDef> GetSignatureDefs(const SavedModelBundle& bundle) {
  std::vector<SignatureDef> signature_defs;
  for (const auto& entry : bundle.meta_graph_def.signature_def()) {
    const SignatureDef& signature_def = entry.second;
    signature_defs.push_back(signature_def);
  }
  return signature_defs;
}
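
// Filename that identifies a TensorFlow Lite model inside a servable version
// directory.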
const char kTfLiteModelFilename[] = "model.tflite";
Status LoadTfLiteModel(const string& model_dir, SavedModelBundle* bundle,
                       const SessionOptions& options,
                       int num_interpreter_pools,
                       int num_interpreters_per_pool) {
  std::unique_ptr<TfLiteSession> session;

  const string& fname = io::JoinPath(model_dir, kTfLiteModelFilename);
  uint64_t size;
  TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(fname, &size));

  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(fname, &file));

  string model_bytes;
  model_bytes.resize(size);
  absl::string_view sv;
  TF_RETURN_IF_ERROR(file->Read(0, size, &sv, &model_bytes[0]));

  std::unique_ptr<TfLiteSession> tflite_session;
  TF_RETURN_IF_ERROR(TfLiteSession::Create(
      std::move(model_bytes), options, num_interpreter_pools,
      num_interpreters_per_pool, &tflite_session,
      bundle->meta_graph_def.mutable_signature_def()));
  bundle->session = std::move(tflite_session);
  return absl::OkStatus();
}
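
// Returns true if `model_dir` contains a file named `model.tflite`.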
bool TfLiteModelFound(const string& model_dir) {
  const string& fname = io::JoinPath(model_dir, kTfLiteModelFilename);
  return Env::Default()->FilesExist({fname}, nullptr);
}

}  // namespace
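
// Creates a SavedModelBundleFactory, constructing the shared batch scheduler
// up front when batching parameters are configured.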
Status SavedModelBundleFactory::Create(
    const SessionBundleConfig& config,
    std::unique_ptr<SavedModelBundleFactory>* factory) {
  std::shared_ptr<Batcher> batcher;
  if (config.has_batching_parameters()) {
    TF_RETURN_IF_ERROR(
        CreateBatchScheduler(config.batching_parameters(), &batcher));
  }
  factory->reset(new SavedModelBundleFactory(config, batcher));
  return absl::OkStatus();
}
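
// Estimates the resources a servable at `path` will consume once loaded.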
Status SavedModelBundleFactory::EstimateResourceRequirement(
    const string& path, ResourceAllocation* estimate) const {
  return EstimateResourceFromPath(
      path, config_.resource_estimation_uses_validation_result(), estimate);
}
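
// Loads the model at `path`, recording the servable's name and version from
// `metadata` in the session's metadata.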
Status SavedModelBundleFactory::CreateSavedModelBundleWithMetadata(
    const Loader::Metadata& metadata, const string& path,
    std::unique_ptr<SavedModelBundle>* bundle) {
  return InternalCreateSavedModelBundle(metadata, path, bundle);
}
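
// Same as above, but without servable metadata.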
Status SavedModelBundleFactory::CreateSavedModelBundle(
    const string& path, std::unique_ptr<SavedModelBundle>* bundle) {
  return InternalCreateSavedModelBundle({}, path, bundle);
}
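
// Shared implementation: builds session options from the config, loads either
// the TFLite model or the SavedModel, optionally trims the MetaGraphDef, and
// wraps the session (with or without batching).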
Status SavedModelBundleFactory::InternalCreateSavedModelBundle(
    const absl::optional<Loader::Metadata>& metadata, const string& path,
    std::unique_ptr<SavedModelBundle>* bundle) {
  bundle->reset(new SavedModelBundle);
  std::unordered_set<string> saved_model_tags(
      config_.saved_model_tags().begin(), config_.saved_model_tags().end());
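  // Defaults to loading the MetaGraphDef tagged `serve` when no
  // saved_model_tags are specified in the config.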
  if (saved_model_tags.empty()) {
    saved_model_tags.insert(kSavedModelTagServe);
  }
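
  // Build the SessionOptions for this servable: apply the configured
  // mixed-precision setting and, when metadata is available, record the
  // servable's name/version as session metadata.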
  const auto& session_options = [&]() {
    auto result = GetSessionOptions(config_);
    string mixed_precision_value = config_.mixed_precision();
    if (!mixed_precision_value.empty()) {
      if (mixed_precision_value == "bfloat16") {
        LOG(INFO) << "Running inference with bfloat16 auto mixed precision";
        tensorflow::ConfigProto& config = result.config;
        GraphOptions* gopt = config.mutable_graph_options();
        RewriterConfig* rwcfg = gopt->mutable_rewrite_options();
        rwcfg->set_auto_mixed_precision_onednn_bfloat16(RewriterConfig::ON);
      } else {
        LOG(WARNING)
            << config_.mixed_precision()
            << " auto mixed precision is not supported. Valid option: bfloat16";
      }
    }
    if (metadata.has_value()) {
      auto* session_metadata =
          result.config.mutable_experimental()->mutable_session_metadata();
      session_metadata->set_name(metadata->servable_id.name);
      session_metadata->set_version(metadata->servable_id.version);
    }
    return result;
  }();
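
  // Prefer the TFLite model when configured to do so and one exists on disk;
  // otherwise fall back to loading the SavedModel/SessionBundle.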
  bool is_tflite = config_.prefer_tflite_model() && TfLiteModelFound(path);
  if (is_tflite) {
    int num_tflite_pools = config_.num_tflite_pools();
    if (num_tflite_pools == 0 && config_.num_tflite_interpreters() > 0) {
      num_tflite_pools = config_.num_tflite_interpreters();
    }
    TF_RETURN_IF_ERROR(LoadTfLiteModel(
        path, bundle->get(), session_options, num_tflite_pools,
        config_.num_tflite_interpreters_per_pool()));
  } else {
    TF_RETURN_IF_ERROR(session_bundle::LoadSessionBundleOrSavedModelBundle(
        session_options, GetRunOptions(config_), path, saved_model_tags,
        config_.enable_saved_model_config(), bundle->get()));
  }
  if (config_.remove_unused_fields_from_bundle_metagraph()) {
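    // Save memory by keeping only the signature_def portion of the loaded
    // MetaGraphDef; everything else is dropped.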
    MetaGraphDef metagraph;
    (*bundle)->meta_graph_def.Swap(&metagraph);
    (*bundle)->meta_graph_def.mutable_signature_def()->swap(
        *metagraph.mutable_signature_def());
  }
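
  // Decide how to wrap the raw session: ignore thread-pool options, wrap for
  // batching, or plain wrapping.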
  if (config_.wrap_session_with_no_threading_params()) {
    return WrapSessionIgnoreThreadPoolOptions(&(*bundle)->session);
  } else if (config_.has_batching_parameters()) {
    absl::optional<BatchingParameters> batching_params;
    TF_RETURN_IF_ERROR(GetPerModelBatchingParams(
        path, config_.batching_parameters(),
        config_.enable_per_model_batching_params(), &batching_params));
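    // Wrap the session so that requests to any SignatureDef in the SavedModel
    // are batched through the shared scheduler.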
    if (batching_params.has_value()) {
      const std::vector<SignatureDef> signatures = GetSignatureDefs(**bundle);
      return WrapSessionForBatching(batching_params.value(), batch_scheduler_,
                                    signatures, &(*bundle)->session);
    }
  }
  return WrapSession(&(*bundle)->session);
}

SavedModelBundleFactory::SavedModelBundleFactory(
    const SessionBundleConfig& config, std::shared_ptr<Batcher> batch_scheduler)
    : config_(config), batch_scheduler_(batch_scheduler) {}

}  // namespace serving
}  // namespace tensorflow