16 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_factory.h"
21 #include "google/protobuf/wrappers.pb.h"
22 #include "tensorflow/cc/saved_model/constants.h"
23 #include "tensorflow/cc/saved_model/tag_constants.h"
24 #include "tensorflow/core/framework/tensor_testutil.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/protobuf/meta_graph.pb.h"
27 #include "tensorflow/core/tfrt/utils/tensor_util.h"
28 #include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
29 #include "tensorflow_serving/session_bundle/graph_rewriter.h"
30 #include "tensorflow_serving/test_util/test_util.h"
32 namespace tensorflow {
36 Loader::Metadata CreateMetadata() {
return {ServableId{
"name", 42}}; }
39 class TfrtSavedModelFactoryTest :
public ::testing::Test {
41 static void SetUpTestSuite() {
42 tfrt_stub::SetGlobalRuntime(
43 tfrt_stub::Runtime::Create(4));
45 TfrtSavedModelFactoryTest()
46 : model_path_(test_util::TestSrcDirPath(
47 "servables/tensorflow/"
48 "testdata/saved_model_half_plus_two_cpu/00000123")) {}
50 Status CreateTfrtSavedModel(
const TfrtSavedModelConfig& config,
51 std::unique_ptr<tfrt::SavedModel>* saved_model) {
52 std::unique_ptr<TfrtSavedModelFactory> factory;
54 TF_RETURN_IF_ERROR(factory->CreateTfrtSavedModelWithMetadata(
55 CreateMetadata(), model_path_, saved_model));
56 return absl::OkStatus();
59 std::vector<string> GetModelFiles() {
60 const string& dir = model_path_;
61 return {io::JoinPath(dir, kSavedModelAssetsDirectory,
"foo.txt"),
62 io::JoinPath(dir, kSavedModelFilenamePb),
63 io::JoinPath(dir, kSavedModelVariablesFilename,
64 "variables.data-00000-of-00001"),
65 io::JoinPath(dir, kSavedModelVariablesFilename,
"variables.index")};
71 TEST_F(TfrtSavedModelFactoryTest, EstimateResourceRequirementWithGoodExport) {
72 TfrtSavedModelConfig config;
73 std::unique_ptr<TfrtSavedModelFactory> factory;
76 ResourceAllocation actual;
77 TF_ASSERT_OK(factory->EstimateResourceRequirement(model_path_, &actual));
79 const double total_file_size = test_util::GetTotalFileSize(GetModelFiles());
80 ResourceAllocation expected =
81 test_util::GetExpectedResourceEstimate(total_file_size);
82 EXPECT_THAT(actual, test_util::EqualsProto(expected));
85 TEST_F(TfrtSavedModelFactoryTest, Basic) {
86 std::unique_ptr<tfrt::SavedModel> saved_model;
87 TfrtSavedModelConfig config;
88 *config.add_saved_model_tags() = kSavedModelTagServe;
89 TF_ASSERT_OK(CreateTfrtSavedModel(config, &saved_model));
91 Tensor input_tensor = test::AsTensor<float>({100.0f, 42.0f}, {2});
92 Tensor expected_output =
93 test::AsTensor<float>({100.0f / 2 + 2, 42.0f / 2 + 2}, {2});
94 std::vector<tensorflow::Tensor> input_tensors;
95 input_tensors.push_back(input_tensor);
96 tfrt::SavedModel::RunOptions run_options;
97 std::vector<tensorflow::Tensor> outputs;
98 TF_ASSERT_OK(saved_model->Run(run_options,
"serving_default", input_tensors,
101 ASSERT_EQ(1, outputs.size());
102 const auto& single_output = outputs.at(0);
103 test::ExpectTensorEqual<float>(expected_output, single_output);
108 TEST_F(TfrtSavedModelFactoryTest, BasicWithSavedModelConfig) {
109 std::unique_ptr<tfrt::SavedModel> saved_model;
110 TfrtSavedModelConfig config;
111 *config.add_saved_model_tags() = kSavedModelTagServe;
112 model_path_ = test_util::TestSrcDirPath(
113 "servables/tensorflow/"
114 "testdata/saved_model_half_plus_two_cpu_with_saved_model_config/"
116 config.set_enable_saved_model_config(
true);
118 TF_ASSERT_OK(CreateTfrtSavedModel(config, &saved_model));
120 Tensor input_tensor = test::AsTensor<float>({100.0f, 42.0f}, {2});
121 Tensor expected_output =
122 test::AsTensor<float>({100.0f / 2 + 2, 42.0f / 2 + 2}, {2});
123 std::vector<tensorflow::Tensor> input_tensors;
124 input_tensors.push_back(input_tensor);
125 tfrt::SavedModel::RunOptions run_options;
126 std::vector<tensorflow::Tensor> outputs;
127 TF_ASSERT_OK(saved_model->Run(run_options,
"serving_default", input_tensors,
130 ASSERT_EQ(1, outputs.size());
131 const auto& single_output = outputs.at(0);
132 test::ExpectTensorEqual<float>(expected_output, single_output);
137 TEST_F(TfrtSavedModelFactoryTest, BasicWithSavedModelConfigAndGraphRewrite) {
138 TF_ASSERT_OK(tensorflow::serving::ResetGraphRewriterForTesting());
139 bool rewriter_was_called =
false;
140 TF_ASSERT_OK(tensorflow::serving::SetGraphRewriter([&](MetaGraphDef* graph) {
141 rewriter_was_called =
true;
142 return absl::OkStatus();
144 std::unique_ptr<tfrt::SavedModel> saved_model;
145 TfrtSavedModelConfig config;
146 *config.add_saved_model_tags() = kSavedModelTagServe;
147 model_path_ = test_util::TestSrcDirPath(
148 "servables/tensorflow/"
149 "testdata/saved_model_half_plus_two_cpu_with_saved_model_config/"
151 config.set_enable_saved_model_config(
true);
153 TF_ASSERT_OK(CreateTfrtSavedModel(config, &saved_model));
154 EXPECT_TRUE(rewriter_was_called);
155 TF_ASSERT_OK(tensorflow::serving::ResetGraphRewriterForTesting());
157 Tensor input_tensor = test::AsTensor<float>({100.0f, 42.0f}, {2});
158 Tensor expected_output =
159 test::AsTensor<float>({100.0f / 2 + 2, 42.0f / 2 + 2}, {2});
160 std::vector<tensorflow::Tensor> input_tensors;
161 input_tensors.push_back(input_tensor);
162 tfrt::SavedModel::RunOptions run_options;
163 std::vector<tensorflow::Tensor> outputs;
164 TF_ASSERT_OK(saved_model->Run(run_options,
"serving_default", input_tensors,
167 ASSERT_EQ(1, outputs.size());
168 const auto& single_output = outputs.at(0);
169 test::ExpectTensorEqual<float>(expected_output, single_output);
172 TEST_F(TfrtSavedModelFactoryTest, Batch) {
173 std::unique_ptr<tfrt::SavedModel> saved_model;
174 TfrtSavedModelConfig config;
175 config.mutable_batching_parameters()->mutable_max_batch_size()->set_value(4);
176 config.mutable_batching_parameters()
177 ->mutable_max_enqueued_batches()
178 ->set_value(INT_MAX);
179 config.mutable_batching_parameters()
180 ->mutable_batch_timeout_micros()
181 ->set_value(1000 * 1000 * 1000);
182 config.mutable_batching_parameters()->mutable_num_batch_threads()->set_value(
184 TF_ASSERT_OK(CreateTfrtSavedModel(config, &saved_model));
186 Tensor input_tensor = test::AsTensor<float>({100.0f, 42.0f}, {2});
187 Tensor expected_output =
188 test::AsTensor<float>({100.0f / 2 + 2, 42.0f / 2 + 2}, {2});
189 std::vector<tensorflow::Tensor> input_tensors;
190 input_tensors.push_back(input_tensor);
191 std::vector<tensorflow::Tensor> output_tensors1, output_tensors2;
192 tfrt::SavedModel::RunOptions run_options;
195 std::vector<std::unique_ptr<Thread>> request_threads;
196 request_threads.reserve(2);
197 request_threads.push_back(
198 std::unique_ptr<Thread>(Env::Default()->StartThread(
199 ThreadOptions(), strings::StrCat(
"thread_", 0),
200 [&saved_model, &run_options, &input_tensors, &output_tensors1]() {
201 TF_ASSERT_OK(saved_model->Run(run_options,
"serving_default",
202 input_tensors, &output_tensors1));
204 request_threads.push_back(
205 std::unique_ptr<Thread>(Env::Default()->StartThread(
206 ThreadOptions(), strings::StrCat(
"thread_", 1),
207 [&saved_model, &run_options, &input_tensors, &output_tensors2]() {
208 TF_ASSERT_OK(saved_model->Run(run_options,
"serving_default",
209 input_tensors, &output_tensors2));
213 ASSERT_EQ(1, output_tensors1.size());
214 test::ExpectTensorEqual<float>(expected_output, output_tensors1.at(0));
216 ASSERT_EQ(1, output_tensors2.size());
217 test::ExpectTensorEqual<float>(expected_output, output_tensors2.at(0));
// Factory entry point used by these tests (declared in
// tfrt_saved_model_factory.h):
//   static absl::Status Create(const TfrtSavedModelConfig& config,
//                              std::unique_ptr<TfrtSavedModelFactory>* factory);