TensorFlow Serving C++ API Documentation
saved_model_bundle_source_adapter_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.h"
17 
18 #include <memory>
19 #include <string>
20 #include <tuple>
21 #include <utility>
22 #include <vector>
23 
24 #include "google/protobuf/wrappers.pb.h"
25 #include <gmock/gmock.h>
26 #include <gtest/gtest.h>
27 #include "tensorflow/cc/saved_model/loader.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/monitoring/gauge.h"
31 #include "tensorflow_serving/core/loader.h"
32 #include "tensorflow_serving/core/servable_data.h"
33 #include "tensorflow_serving/core/test_util/session_test_util.h"
34 #include "tensorflow_serving/resources/resource_util.h"
35 #include "tensorflow_serving/resources/resource_values.h"
36 #include "tensorflow_serving/resources/resources.pb.h"
37 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
38 #include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
39 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
40 #include "tensorflow_serving/test_util/test_util.h"
41 #include "tensorflow_serving/util/oss_or_google.h"
42 
43 namespace tensorflow {
44 namespace serving {
45 namespace {
46 
47 using test_util::EqualsProto;
48 
49 Loader::Metadata CreateMetadata() { return {ServableId{"name", 42}}; }
50 
51 class SavedModelBundleSourceAdapterTest
52  : public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {
53  protected:
54  SavedModelBundleSourceAdapterTest() {
55  ResourceUtil::Options resource_util_options;
56  resource_util_options.devices = {{device_types::kMain, 1}};
57  resource_util_ =
58  std::unique_ptr<ResourceUtil>(new ResourceUtil(resource_util_options));
59 
60  ram_resource_ = resource_util_->CreateBoundResource(
61  device_types::kMain, resource_kinds::kRamBytes);
62  config_.mutable_legacy_config()->set_enable_model_warmup(EnableWarmup());
63  if (EnableNumRequestIterations()) {
64  config_.mutable_legacy_config()
65  ->mutable_model_warmup_options()
66  ->mutable_num_request_iterations()
67  ->set_value(2);
68  }
69 
70  config_.mutable_legacy_config()->set_enable_session_metadata(
71  EnableSessionMetadata());
72 
73  config_.mutable_legacy_config()->set_session_target(
74  test_util::kNewSessionHookSessionTargetPrefix);
75  test_util::SetNewSessionHook([&](const SessionOptions& session_options) {
76  EXPECT_EQ(EnableSessionMetadata(),
77  session_options.config.experimental().has_session_metadata());
78  if (EnableSessionMetadata()) {
79  const auto& actual_session_metadata =
80  session_options.config.experimental().session_metadata();
81  const auto& expected_loader_metadata = CreateMetadata();
82  EXPECT_EQ(expected_loader_metadata.servable_id.name,
83  actual_session_metadata.name());
84  EXPECT_EQ(expected_loader_metadata.servable_id.version,
85  actual_session_metadata.version());
86  }
87  return absl::OkStatus();
88  });
89  }
90 
91  void TestSavedModelBundleSourceAdapter(const string& export_dir) const {
92  std::unique_ptr<Loader> loader;
93  {
94  std::unique_ptr<SavedModelBundleSourceAdapter> adapter;
95  TF_CHECK_OK(SavedModelBundleSourceAdapter::Create(config_, &adapter));
96  ServableData<std::unique_ptr<Loader>> loader_data =
97  adapter->AdaptOneVersion(
98  ServableData<StoragePath>({"", 0}, export_dir));
99  TF_ASSERT_OK(loader_data.status());
100  loader = loader_data.ConsumeDataOrDie();
101 
102  // Let the adapter fall out of scope and be deleted. The loader we got
103  // from it should be unaffected. Regression test coverage for b/30202207.
104  }
105 
106  // We should get a non-empty resource estimate, and we should get the same
107  // value twice (via memoization).
108  ResourceAllocation first_resource_estimate;
109  TF_ASSERT_OK(loader->EstimateResources(&first_resource_estimate));
110  EXPECT_FALSE(first_resource_estimate.resource_quantities().empty());
111  ResourceAllocation second_resource_estimate;
112  TF_ASSERT_OK(loader->EstimateResources(&second_resource_estimate));
113  EXPECT_THAT(second_resource_estimate, EqualsProto(first_resource_estimate));
114 
115  const auto metadata = CreateMetadata();
116  TF_ASSERT_OK(loader->LoadWithMetadata(CreateMetadata()));
117 
118  // We should get a new (lower) resource estimate post-load.
119  ResourceAllocation expected_post_load_resource_estimate =
120  first_resource_estimate;
121  resource_util_->SetQuantity(
122  ram_resource_,
123  resource_util_->GetQuantity(ram_resource_, first_resource_estimate) -
124  config_.legacy_config()
125  .experimental_transient_ram_bytes_during_load(),
126  &expected_post_load_resource_estimate);
127  ResourceAllocation actual_post_load_resource_estimate;
128  TF_ASSERT_OK(
129  loader->EstimateResources(&actual_post_load_resource_estimate));
130  EXPECT_THAT(actual_post_load_resource_estimate,
131  EqualsProto(expected_post_load_resource_estimate));
132 
133  const SavedModelBundle* bundle = loader->servable().get<SavedModelBundle>();
134  test_util::TestSingleRequest(bundle->session.get());
135 
136  loader->Unload();
137  }
138 
139  bool EnableWarmup() const { return std::get<0>(GetParam()); }
140  bool EnableNumRequestIterations() const { return std::get<1>(GetParam()); }
141  bool EnableSessionMetadata() const { return std::get<2>(GetParam()); }
142 
143  std::unique_ptr<ResourceUtil> resource_util_;
144  Resource ram_resource_;
145  SavedModelBundleSourceAdapterConfig config_;
146 };
147 
148 TEST_P(SavedModelBundleSourceAdapterTest, Basic) {
149  config_.mutable_legacy_config()
150  ->set_experimental_transient_ram_bytes_during_load(42);
151 
152  TestSavedModelBundleSourceAdapter(test_util::GetTestSavedModelPath());
153 }
154 
155 TEST_P(SavedModelBundleSourceAdapterTest, BackwardCompatibility) {
156  if (IsTensorflowServingOSS()) {
157  return;
158  }
159  TestSavedModelBundleSourceAdapter(
160  test_util::GetTestSessionBundleExportPath());
161 }
162 
163 TEST_P(SavedModelBundleSourceAdapterTest, MLMetadata) {
164  if (!EnableSessionMetadata()) return;
165  TestSavedModelBundleSourceAdapter(test_util::TestSrcDirPath(
166  strings::StrCat("/servables/tensorflow/testdata/",
167  "saved_model_half_plus_two_mlmd/00000123")));
168  auto* collection_registry = monitoring::CollectionRegistry::Default();
169  monitoring::CollectionRegistry::CollectMetricsOptions options;
170  const std::unique_ptr<monitoring::CollectedMetrics> collected_metrics =
171  collection_registry->CollectMetrics(options);
172  const monitoring::PointSet& lps =
173  *collected_metrics->point_set_map.at("/tensorflow/serving/mlmd_map");
174 
175  EXPECT_EQ(1, lps.points.size());
176  EXPECT_EQ(2, lps.points[0]->labels.size());
177  EXPECT_EQ("model_name", lps.points[0]->labels[0].name);
178  EXPECT_EQ("name", lps.points[0]->labels[0].value);
179  EXPECT_EQ("version", lps.points[0]->labels[1].name);
180  EXPECT_EQ("42", lps.points[0]->labels[1].value);
181  EXPECT_EQ("test_mlmd_uuid", lps.points[0]->string_value);
182 }
183 
184 // Test all SavedModelBundleSourceAdapterTest test cases with
185 // warmup, num_request_iterations enabled/disabled and session-metadata
186 // enabled/disabled.
187 INSTANTIATE_TEST_CASE_P(VariousOptions, SavedModelBundleSourceAdapterTest,
188  ::testing::Combine(::testing::Bool(), ::testing::Bool(),
189  ::testing::Bool()));
190 
191 } // namespace
192 } // namespace serving
193 } // namespace tensorflow