16 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.h"
18 #include "google/protobuf/wrappers.pb.h"
19 #include "tensorflow/core/framework/tensor_testutil.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/lib/monitoring/collection_registry.h"
22 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
23 #include "tensorflow/core/tfrt/utils/tensor_util.h"
24 #include "tensorflow_serving/core/loader.h"
25 #include "tensorflow_serving/resources/resource_util.h"
26 #include "tensorflow_serving/resources/resource_values.h"
27 #include "tensorflow_serving/resources/resources.pb.h"
28 #include "tensorflow_serving/servables/tensorflow/servable.h"
29 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
30 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
31 #include "tensorflow_serving/test_util/test_util.h"

namespace tensorflow {
namespace serving {
namespace {

using test_util::EqualsProto;

Loader::Metadata CreateMetadata() { return {ServableId{"name", 42}}; }
class TfrtSavedModelSourceAdapterTest
    : public ::testing::TestWithParam<std::tuple<bool, bool>> {
 public:
  static void SetUpTestSuite() {
    tfrt_stub::SetGlobalRuntime(
        tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4));
  }

 protected:
  TfrtSavedModelSourceAdapterTest() {
    ResourceUtil::Options resource_util_options;
    resource_util_options.devices = {{device_types::kMain, 1}};
    resource_util_ =
        std::unique_ptr<ResourceUtil>(new ResourceUtil(resource_util_options));

    ram_resource_ = resource_util_->CreateBoundResource(
        device_types::kMain, resource_kinds::kRamBytes);
    config_.mutable_saved_model_config()
        ->mutable_legacy_config()
        ->set_enable_model_warmup(EnableWarmup());
    if (EnableNumRequestIterations()) {
      config_.mutable_saved_model_config()
          ->mutable_legacy_config()
          ->mutable_model_warmup_options()
          ->mutable_num_request_iterations()
          ->set_value(2);
    }

    config_.mutable_saved_model_config()
        ->mutable_legacy_config()
        ->set_enable_session_metadata(true);
  }
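
  // Adapts `export_dir` into a Loader, checks that resource estimates are
  // stable across calls, loads the servable, and runs a single request.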
  void TestTFRTSavedModelSourceAdapter(const string& export_dir) const {
    std::unique_ptr<Loader> loader;

    std::unique_ptr<TfrtSavedModelSourceAdapter> adapter;
    TF_CHECK_OK(TfrtSavedModelSourceAdapter::Create(config_, &adapter));
    ServableData<std::unique_ptr<Loader>> loader_data =
        adapter->AdaptOneVersion(
            ServableData<StoragePath>({"", 0}, export_dir));
    TF_ASSERT_OK(loader_data.status());
    loader = loader_data.ConsumeDataOrDie();

    ResourceAllocation first_resource_estimate;
    TF_ASSERT_OK(loader->EstimateResources(&first_resource_estimate));
    EXPECT_FALSE(first_resource_estimate.resource_quantities().empty());
    ResourceAllocation second_resource_estimate;
    TF_ASSERT_OK(loader->EstimateResources(&second_resource_estimate));
    EXPECT_THAT(second_resource_estimate, EqualsProto(first_resource_estimate));

    const auto metadata = CreateMetadata();
    TF_ASSERT_OK(loader->LoadWithMetadata(metadata));

    // The post-load estimate should still report the RAM quantity from the
    // initial estimate.
    ResourceAllocation expected_post_load_resource_estimate =
        first_resource_estimate;
    resource_util_->SetQuantity(
        ram_resource_,
        resource_util_->GetQuantity(ram_resource_, first_resource_estimate),
        &expected_post_load_resource_estimate);
    ResourceAllocation actual_post_load_resource_estimate;
    TF_ASSERT_OK(
        loader->EstimateResources(&actual_post_load_resource_estimate));
    EXPECT_THAT(actual_post_load_resource_estimate,
                EqualsProto(expected_post_load_resource_estimate));
113 tfrt::SavedModel& saved_model =
114 down_cast<TfrtSavedModelServable*>(loader->servable().get<Servable>())
116 TestSingleRequest(&saved_model);
  void TestSingleRequest(tfrt::SavedModel* saved_model) const {
    Tensor input = test::AsTensor<float>({100.0f, 42.0f}, {2});
    // half_plus_two computes output = input / 2 + 2.
    Tensor expected_output =
        test::AsTensor<float>({100.0f / 2 + 2, 42.0f / 2 + 2}, {2});

    std::vector<tensorflow::Tensor> input_tensors;
    input_tensors.push_back(input);
    tfrt::SavedModel::RunOptions run_options;
    std::vector<tensorflow::Tensor> output_tensors;
    TF_ASSERT_OK(saved_model->Run(run_options, "serving_default",
                                  input_tensors, &output_tensors));

    ASSERT_EQ(1, output_tensors.size());
    const auto& single_output = output_tensors.at(0);
    test::ExpectTensorEqual<float>(expected_output, single_output);
  }

  bool EnableWarmup() const { return std::get<0>(GetParam()); }
  bool EnableNumRequestIterations() const { return std::get<1>(GetParam()); }

  std::unique_ptr<ResourceUtil> resource_util_;
  Resource ram_resource_;
  TfrtSavedModelSourceAdapterConfig config_;
};
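
// Exercises the full adapt-load-run flow on the half_plus_two CPU model.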
TEST_P(TfrtSavedModelSourceAdapterTest, Basic) {
  TestTFRTSavedModelSourceAdapter(
      test_util::TestSrcDirPath("servables/tensorflow/testdata/"
                                "saved_model_half_plus_two_cpu/00000123"));
}
TEST_P(TfrtSavedModelSourceAdapterTest, MLMetadata) {
  TestTFRTSavedModelSourceAdapter(
      test_util::TestSrcDirPath("servables/tensorflow/testdata/"
                                "saved_model_half_plus_two_mlmd/00000123"));
  auto* collection_registry = monitoring::CollectionRegistry::Default();
  monitoring::CollectionRegistry::CollectMetricsOptions options;
  const std::unique_ptr<monitoring::CollectedMetrics> collected_metrics =
      collection_registry->CollectMetrics(options);
  const monitoring::PointSet& lps =
      *collected_metrics->point_set_map.at("/tensorflow/serving/mlmd_map");

  EXPECT_EQ(1, lps.points.size());
  EXPECT_EQ(2, lps.points[0]->labels.size());
  EXPECT_EQ("model_name", lps.points[0]->labels[0].name);
  EXPECT_EQ("name", lps.points[0]->labels[0].value);
  EXPECT_EQ("version", lps.points[0]->labels[1].name);
  EXPECT_EQ("42", lps.points[0]->labels[1].value);
  EXPECT_EQ("test_mlmd_uuid", lps.points[0]->string_value);
}

INSTANTIATE_TEST_CASE_P(VariousOptions, TfrtSavedModelSourceAdapterTest,
                        ::testing::Combine(::testing::Bool(),
                                           ::testing::Bool()));

}  // namespace
}  // namespace serving
}  // namespace tensorflow