16 #include "tensorflow_serving/servables/tensorflow/machine_learning_metadata.h"
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/monitoring/gauge.h"
25 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
26 #include "tensorflow_serving/test_util/test_util.h"
28 namespace tensorflow {
// Key of the tsl::monitoring metric whose points carry the published MLMD
// uuid; GetMlmdUuid looks this key up in the collected point-set map.
constexpr char mlmd_streamz[] = "/tensorflow/serving/mlmd_map";
34 bool GetMlmdUuid(
const string& model_name,
const string& version,
35 std::string* mlmd_uuid) {
36 auto* collection_registry = tsl::monitoring::CollectionRegistry::Default();
37 tsl::monitoring::CollectionRegistry::CollectMetricsOptions options;
38 const std::unique_ptr<tsl::monitoring::CollectedMetrics> collected_metrics =
39 collection_registry->CollectMetrics(options);
40 const auto& point_set_map = collected_metrics->point_set_map;
41 if (point_set_map.empty() ||
42 point_set_map.find(mlmd_streamz) == point_set_map.end())
44 const tsl::monitoring::PointSet& lps =
45 *collected_metrics->point_set_map.at(mlmd_streamz);
46 for (
int i = 0; i < lps.points.size(); ++i) {
47 if ((lps.points[i]->labels[0].name ==
"model_name") &&
48 (lps.points[i]->labels[0].value == model_name) &&
49 (lps.points[i]->labels[1].name ==
"version") &&
50 (lps.points[i]->labels[1].value == version)) {
51 *mlmd_uuid = lps.points[i]->string_value;
58 TEST(MachineLearningMetaDataTest, BasicTest_MLMD_missing) {
59 std::string mlmd_uuid;
60 ASSERT_FALSE(GetMlmdUuid(
"missing_model",
"9696", &mlmd_uuid));
61 string test_data_path = test_util::GetTestSavedModelPath();
62 MaybePublishMLMDStreamz(test_data_path,
"missing_model", 9696);
63 EXPECT_FALSE(GetMlmdUuid(
"missing_model",
"9696", &mlmd_uuid));
// Publishes the testdata SavedModel under
// servables/tensorflow/testdata/saved_model_half_plus_two_mlmd/00000123 as
// ("test_model", 9696) and expects GetMlmdUuid to then report the uuid
// "test_mlmd_uuid" for that model version.
// NOTE(review): the leading integers on the lines below (66, 67, ...) look
// like pasted line numbers; they are not valid C++ and should be removed.
66 TEST(MachineLearningMetaDataTest, BasicTest_MLMD_present) {
67 std::string mlmd_uuid;
// Precondition: nothing has been published for "test_model" yet.
68 ASSERT_FALSE(GetMlmdUuid(
"test_model",
"9696", &mlmd_uuid));
69 const string test_data_path = test_util::TestSrcDirPath(
70 strings::StrCat(
"/servables/tensorflow/testdata/",
71 "saved_model_half_plus_two_mlmd/00000123"));
72 MaybePublishMLMDStreamz(test_data_path,
"test_model", 9696);
// After publishing, the uuid for this model version must be visible.
73 EXPECT_TRUE(GetMlmdUuid(
"test_model",
"9696", &mlmd_uuid));
74 EXPECT_EQ(
"test_mlmd_uuid", mlmd_uuid);