16 #include "tensorflow_serving/servables/hashmap/hashmap_source_adapter.h"
19 #include <unordered_map>
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow_serving/core/loader.h"
33 #include "tensorflow_serving/core/servable_data.h"
34 #include "tensorflow_serving/servables/hashmap/hashmap_source_adapter.pb.h"
35 #include "tensorflow_serving/util/any_ptr.h"
37 using ::testing::Pair;
38 using ::testing::UnorderedElementsAre;
40 namespace tensorflow {
// Map type used both as the data written to disk by WriteHashmapToFile and as
// the servable type read back from the loader (servable().get<Hashmap>()).
44 using Hashmap = std::unordered_map<string, string>;
// Writes `hashmap` to a newly created file at `file_name` (opened through
// Env::Default()->NewWritableFile) using the on-disk layout selected by
// `format`.
//
// SIMPLE_CSV layout: one "key,value\n" line per entry, emitted in the
// hashmap's iteration order (unspecified for std::unordered_map).
//
// Returns a non-OK Status on the first failure to create, append to, or close
// the file, and (presumably for any format other than SIMPLE_CSV) an
// InvalidArgument error naming the unrecognized format.
//
// NOTE(review): keys and values are written unescaped. An entry containing
// ',' or '\n' would yield a line the reader may split differently. That is fine
// for the fixed fixtures in this test, but avoid reusing this helper with
// arbitrary data.
47 Status WriteHashmapToFile(
const HashmapSourceAdapterConfig::Format format,
48 const string& file_name,
const Hashmap& hashmap) {
49 std::unique_ptr<WritableFile> file;
50 TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(file_name, &file));
// SIMPLE_CSV: emit each entry as a single comma-separated line.
52 case HashmapSourceAdapterConfig::SIMPLE_CSV: {
53 for (
const auto& entry : hashmap) {
54 const string& key = entry.first;
55 const string& value = entry.second;
56 const string line = strings::StrCat(key,
",", value,
"\n");
57 TF_RETURN_IF_ERROR(file->Append(line));
// Presumably the fallback for any format other than SIMPLE_CSV.
62 return errors::InvalidArgument(
"Unrecognized format enum value: ",
// Close explicitly so that a failure while closing is reported to the caller.
65 TF_RETURN_IF_ERROR(file->Close());
// End-to-end test with four steps:
// 1. Write {"a": "apple", "b": "banana"} to a SIMPLE_CSV file under
//    testing::TmpDir().
// 2. Adapt that file into a Loader with a HashmapSourceAdapter.
// 3. Load it.
// 4. Check that the servable holds exactly the written entries.
69 TEST(HashmapSourceAdapter, Basic) {
70 const auto format = HashmapSourceAdapterConfig::SIMPLE_CSV;
71 const string file = io::JoinPath(testing::TmpDir(),
"Basic");
// Write the fixture file that the adapter will read.
73 WriteHashmapToFile(format, file, {{
"a",
"apple"}, {
"b",
"banana"}}));
// Configure the adapter for the same format the file was written in.
75 HashmapSourceAdapterConfig config;
76 config.set_format(format);
78 std::unique_ptr<HashmapSourceAdapter>(
new HashmapSourceAdapter(config));
// NOTE(review): {"", 0} is presumably the ServableId (empty name, version 0),
// paired with the file path as the data to adapt. Confirm against
// ServableData's constructor.
79 ServableData<std::unique_ptr<Loader>> loader_data =
80 adapter->AdaptOneVersion({{
"", 0}, file});
81 TF_ASSERT_OK(loader_data.status());
// Move the loader out of loader_data (its status was asserted OK just above).
82 std::unique_ptr<Loader> loader = loader_data.ConsumeDataOrDie();
84 TF_ASSERT_OK(loader->Load());
// NOTE(review): consider ASSERT_NE(hashmap, nullptr) before inspecting the map,
// in case get<Hashmap>() yields null on a type mismatch. Confirm the AnyPtr
// semantics.
86 const Hashmap* hashmap = loader->servable().get<Hashmap>();
// The loaded map must hold exactly the two written entries, in any order.
88 UnorderedElementsAre(Pair(
"a",
"apple"), Pair(
"b",
"banana")));