16 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.h"
24 #include "google/protobuf/wrappers.pb.h"
25 #include <gmock/gmock.h>
26 #include <gtest/gtest.h>
27 #include "tensorflow/cc/saved_model/loader.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/monitoring/gauge.h"
31 #include "tensorflow_serving/core/loader.h"
32 #include "tensorflow_serving/core/servable_data.h"
33 #include "tensorflow_serving/core/test_util/session_test_util.h"
34 #include "tensorflow_serving/resources/resource_util.h"
35 #include "tensorflow_serving/resources/resource_values.h"
36 #include "tensorflow_serving/resources/resources.pb.h"
37 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
38 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
39 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
40 #include "tensorflow_serving/test_util/test_util.h"
41 #include "tensorflow_serving/util/oss_or_google.h"
43 namespace tensorflow {
47 using test_util::EqualsProto;
49 Loader::Metadata CreateMetadata() {
return {ServableId{
"name", 42}}; }
// Value-parameterized fixture for SavedModelBundleSourceAdapter. The
// std::tuple<bool, bool, bool> parameter is read through EnableWarmup()
// (element 0), EnableNumRequestIterations() (element 1) and
// EnableSessionMetadata() (element 2); the constructor uses them to set
// fields of config_.legacy_config().
51 class SavedModelBundleSourceAdapterTest
52 :
public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {
// Builds a ResourceUtil over a single main device, binds the main-device
// RAM-bytes resource, fills config_ from the test parameters, and installs a
// new-session hook that checks the SessionOptions of each created session.
54 SavedModelBundleSourceAdapterTest() {
55 ResourceUtil::Options resource_util_options;
56 resource_util_options.devices = {{device_types::kMain, 1}};
58 std::unique_ptr<ResourceUtil>(
new ResourceUtil(resource_util_options));
60 ram_resource_ = resource_util_->CreateBoundResource(
61 device_types::kMain, resource_kinds::kRamBytes);
62 config_.mutable_legacy_config()->set_enable_model_warmup(EnableWarmup());
// The num_request_iterations warmup option is only touched when the second
// test parameter is true.
63 if (EnableNumRequestIterations()) {
64 config_.mutable_legacy_config()
65 ->mutable_model_warmup_options()
66 ->mutable_num_request_iterations()
70 config_.mutable_legacy_config()->set_enable_session_metadata(
71 EnableSessionMetadata());
// Presumably routes session creation through the hook installed below
// (kNewSessionHookSessionTargetPrefix) — confirm in session_test_util.h.
73 config_.mutable_legacy_config()->set_session_target(
74 test_util::kNewSessionHookSessionTargetPrefix);
// NOTE(review): the hook captures by reference ([&], which includes `this`);
// confirm the hook is cleared or replaced before this fixture is destroyed.
75 test_util::SetNewSessionHook([&](
const SessionOptions& session_options) {
// Session metadata must be present exactly when it was enabled, and when
// present it must carry CreateMetadata()'s servable name and version.
76 EXPECT_EQ(EnableSessionMetadata(),
77 session_options.config.experimental().has_session_metadata());
78 if (EnableSessionMetadata()) {
79 const auto& actual_session_metadata =
80 session_options.config.experimental().session_metadata();
81 const auto& expected_loader_metadata = CreateMetadata();
82 EXPECT_EQ(expected_loader_metadata.servable_id.name,
83 actual_session_metadata.name());
84 EXPECT_EQ(expected_loader_metadata.servable_id.version,
85 actual_session_metadata.version());
87 return absl::OkStatus();
// Adapts `export_dir` into a Loader using config_ and checks it end to end.
// Resource estimates must be non-empty and identical across two calls.
// After LoadWithMetadata(), the estimate must equal the pre-load estimate
// with its RAM quantity reduced by experimental_transient_ram_bytes_during_load.
// Finally, the loaded SavedModelBundle's session is passed to
// test_util::TestSingleRequest().
91 void TestSavedModelBundleSourceAdapter(
const string& export_dir)
const {
92 std::unique_ptr<Loader> loader;
94 std::unique_ptr<SavedModelBundleSourceAdapter> adapter;
95 TF_CHECK_OK(SavedModelBundleSourceAdapter::Create(config_, &adapter));
96 ServableData<std::unique_ptr<Loader>> loader_data =
97 adapter->AdaptOneVersion(
98 ServableData<StoragePath>({
"", 0}, export_dir));
99 TF_ASSERT_OK(loader_data.status());
100 loader = loader_data.ConsumeDataOrDie();
// Estimating twice must give the same answer.
108 ResourceAllocation first_resource_estimate;
109 TF_ASSERT_OK(loader->EstimateResources(&first_resource_estimate));
110 EXPECT_FALSE(first_resource_estimate.resource_quantities().empty());
111 ResourceAllocation second_resource_estimate;
112 TF_ASSERT_OK(loader->EstimateResources(&second_resource_estimate));
113 EXPECT_THAT(second_resource_estimate, EqualsProto(first_resource_estimate));
// NOTE(review): `metadata` appears unused in the visible lines;
// LoadWithMetadata() is passed a fresh CreateMetadata() instead.
115 const auto metadata = CreateMetadata();
116 TF_ASSERT_OK(loader->LoadWithMetadata(CreateMetadata()));
// Expected post-load estimate: same as the first estimate, except that the
// RAM quantity drops by the configured transient-during-load amount.
119 ResourceAllocation expected_post_load_resource_estimate =
120 first_resource_estimate;
121 resource_util_->SetQuantity(
123 resource_util_->GetQuantity(ram_resource_, first_resource_estimate) -
124 config_.legacy_config()
125 .experimental_transient_ram_bytes_during_load(),
126 &expected_post_load_resource_estimate);
127 ResourceAllocation actual_post_load_resource_estimate;
129 loader->EstimateResources(&actual_post_load_resource_estimate));
130 EXPECT_THAT(actual_post_load_resource_estimate,
131 EqualsProto(expected_post_load_resource_estimate));
133 const SavedModelBundle* bundle = loader->servable().get<SavedModelBundle>();
134 test_util::TestSingleRequest(bundle->session.get());
// Accessors for the three boolean test parameters (tuple elements 0, 1, 2).
139 bool EnableWarmup()
const {
return std::get<0>(GetParam()); }
140 bool EnableNumRequestIterations()
const {
return std::get<1>(GetParam()); }
141 bool EnableSessionMetadata()
const {
return std::get<2>(GetParam()); }
// Initialized in the constructor. Individual tests may further modify
// config_ before calling TestSavedModelBundleSourceAdapter(), which reads it.
143 std::unique_ptr<ResourceUtil> resource_util_;
144 Resource ram_resource_;
145 SavedModelBundleSourceAdapterConfig config_;
// Sets experimental_transient_ram_bytes_during_load to 42. The post-load RAM
// estimate is then expected to be 42 lower than the pre-load estimate, and
// the shared checks are run on the standard test SavedModel.
148 TEST_P(SavedModelBundleSourceAdapterTest, Basic) {
149 config_.mutable_legacy_config()
150 ->set_experimental_transient_ram_bytes_during_load(42);
152 TestSavedModelBundleSourceAdapter(test_util::GetTestSavedModelPath());
// Runs the shared adapter checks against the legacy session-bundle export
// path. NOTE(review): the body of the IsTensorflowServingOSS() branch is not
// visible here — presumably it skips the test in OSS builds; confirm.
155 TEST_P(SavedModelBundleSourceAdapterTest, BackwardCompatibility) {
156 if (IsTensorflowServingOSS()) {
159 TestSavedModelBundleSourceAdapter(
160 test_util::GetTestSessionBundleExportPath());
// Only runs when session metadata is enabled. Loads the MLMD-annotated
// half_plus_two test model. It then checks that the /tensorflow/serving/mlmd_map
// metric in the default CollectionRegistry holds exactly one point with labels
// model_name="name" and version="42" (CreateMetadata()'s values) and string
// value "test_mlmd_uuid".
163 TEST_P(SavedModelBundleSourceAdapterTest, MLMetadata) {
164 if (!EnableSessionMetadata())
return;
165 TestSavedModelBundleSourceAdapter(test_util::TestSrcDirPath(
166 strings::StrCat(
"/servables/tensorflow/testdata/",
167 "saved_model_half_plus_two_mlmd/00000123")));
168 auto* collection_registry = monitoring::CollectionRegistry::Default();
169 monitoring::CollectionRegistry::CollectMetricsOptions options;
170 const std::unique_ptr<monitoring::CollectedMetrics> collected_metrics =
171 collection_registry->CollectMetrics(options);
// NOTE(review): at() fails hard (throws or aborts, depending on the map type
// and build) if the metric is absent, instead of producing a readable test
// failure.
172 const monitoring::PointSet& lps =
173 *collected_metrics->point_set_map.at(
"/tensorflow/serving/mlmd_map");
// NOTE(review): these size checks are non-fatal EXPECT_EQs. If points or
// labels are shorter than expected, the [0]/[1] indexing below reads out of
// bounds; ASSERT_EQ would stop the test first. Comparing the int literal with
// size() may also trigger -Wsign-compare.
175 EXPECT_EQ(1, lps.points.size());
176 EXPECT_EQ(2, lps.points[0]->labels.size());
177 EXPECT_EQ(
"model_name", lps.points[0]->labels[0].name);
178 EXPECT_EQ(
"name", lps.points[0]->labels[0].value);
179 EXPECT_EQ(
"version", lps.points[0]->labels[1].name);
180 EXPECT_EQ(
"42", lps.points[0]->labels[1].value);
181 EXPECT_EQ(
"test_mlmd_uuid", lps.points[0]->string_value);
187 INSTANTIATE_TEST_CASE_P(VariousOptions, SavedModelBundleSourceAdapterTest,
188 ::testing::Combine(::testing::Bool(), ::testing::Bool(),