16 #include "tensorflow_serving/servables/tensorflow/tfrt_multi_inference.h"
18 #include "absl/status/status.h"
19 #include "tensorflow/cc/saved_model/loader.h"
20 #include "tensorflow/cc/saved_model/signature_constants.h"
21 #include "xla/tsl/lib/core/status_test_util.h"
22 #include "tensorflow/core/example/example.pb.h"
23 #include "tensorflow/core/example/feature.pb.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
26 #include "tensorflow_serving/apis/classification.pb.h"
27 #include "tensorflow_serving/apis/input.pb.h"
28 #include "tensorflow_serving/apis/regression.pb.h"
29 #include "tensorflow_serving/core/availability_preserving_policy.h"
30 #include "tensorflow_serving/core/servable_handle.h"
31 #include "tensorflow_serving/model_servers/model_platform_types.h"
32 #include "tensorflow_serving/model_servers/platform_config_util.h"
33 #include "tensorflow_serving/model_servers/server_core.h"
34 #include "tensorflow_serving/servables/tensorflow/servable.h"
35 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
36 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
37 #include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
38 #include "tensorflow_serving/test_util/test_util.h"
40 namespace tensorflow {
// Model name used both when registering the model with the server core and
// when filling in the ModelSpec of the requests built by this file.
44 constexpr
char kTestModelName[] =
"test_model";
// Model version the fixture stores in servable_version_ and the positive
// tests expect to see echoed back in the response's ModelSpec.
// NOTE(review): presumably matches the version directory of the testdata
// model — confirm.
45 constexpr
int kTestModelVersion = 123;
// Test fixture for MultiInference on a servable obtained from a ServerCore.
// One ServerCore is kept in a static member and shared by every test in the
// suite; it is created in SetUpTestSuite() and released in
// TearDownTestSuite().
47 class TfrtMultiInferenceTest :
public ::testing::Test {
// Suite-level setup: installs a global TFRT runtime, then builds the shared
// server core.
// NOTE(review): the argument 4 is presumably a worker-thread count — confirm
// against tfrt_stub::Runtime::Create.
49 static void SetUpTestSuite() {
50 tfrt_stub::SetGlobalRuntime(
51 tfrt_stub::Runtime::Create(4));
52 CreateServerCore(&server_core_);
// Suite-level teardown: drops the shared server core.
55 static void TearDownTestSuite() { server_core_.reset(); }
// Assembles ServerCore::Options for a single model and (presumably) creates
// the core into `*server_core`.
//   server_core: output location for the created ServerCore.
58 static void CreateServerCore(std::unique_ptr<ServerCore>* server_core) {
// Model config: one entry named kTestModelName, loaded from the
// saved_model_half_plus_two_cpu testdata directory on the TensorFlow platform.
59 ModelServerConfig config;
60 auto model_config = config.mutable_model_config_list()->add_config();
61 model_config->set_name(kTestModelName);
62 model_config->set_base_path(
63 test_util::TestSrcDirPath(
"servables/tensorflow/"
64 "testdata/saved_model_half_plus_two_cpu"));
65 model_config->set_model_platform(kTensorFlowModelPlatform);
69 ServerCore::Options options;
70 options.model_server_config = config;
// Register a default-constructed TfrtSavedModelSourceAdapterConfig as the
// source adapter for the TensorFlow platform, packed into an Any.
71 PlatformConfigMap platform_config_map;
72 ::google::protobuf::Any source_adapter_config;
73 TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
74 source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
75 (*(*platform_config_map
76 .mutable_platform_configs())[kTensorFlowModelPlatform]
77 .mutable_source_adapter_config()) = source_adapter_config;
78 options.platform_config_map = platform_config_map;
79 options.aspired_version_policy =
80 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
// Use the same thread count for the initial load as for later loads.
83 options.num_initial_load_threads = options.num_load_threads;
// NOTE(review): the statement that actually constructs the ServerCore from
// `options` and stores it in `*server_core` is not among the visible lines —
// confirm it follows here.
// Returns a non-owning pointer to the suite-wide server core.
87 ServerCore* GetServerCore() {
return server_core_.get(); }
// Looks up a handle to the test model's servable via the shared server core.
//   servable: output handle filled in by ServerCore::GetServableHandle.
// Returns the status reported by that lookup.
// NOTE(review): the declaration of `model_spec` is not visible in this view;
// only its name is set here, so no specific version appears to be requested —
// confirm.
89 Status GetServableHandle(ServableHandle<Servable>* servable) {
91 model_spec.set_name(kTestModelName);
92 return GetServerCore()->GetServableHandle(model_spec, servable);
// Version the tests put into requests and expect in responses.
95 const int64_t servable_version_ = kTestModelVersion;
// Server core shared by all tests in the suite; defined out of class.
98 static std::unique_ptr<ServerCore> server_core_;
// Out-of-class definition of the fixture's static server core; starts out
// empty (default-constructed unique_ptr).
101 std::unique_ptr<ServerCore> TfrtMultiInferenceTest::server_core_;
// Appends one example to the request's input example list and adds each
// (feature name, value) pair to it as a float-list feature.
//   feature_kv: feature names paired with the float value to append.
//   request:    request whose input is extended.
// NOTE(review): the declaration of `example` is not visible in this view; it
// is presumably bound to the result of add_examples() below — confirm.
106 void AddInput(
const std::vector<std::pair<string, float>>& feature_kv,
107 MultiInferenceRequest* request) {
109 request->mutable_input()->mutable_example_list()->add_examples();
110 auto* features = example->mutable_features()->mutable_feature();
111 for (
const auto& feature : feature_kv) {
// operator[] on the feature map creates the entry if it is absent, so a
// repeated name accumulates multiple values in the same float list.
112 (*features)[feature.first].mutable_float_list()->add_value(feature.second);
// Fills an InferenceTask with a ModelSpec for kTestModelName plus the given
// signature name and method name.
//   signature_name: signature the task should evaluate.
//   method_name:    method name stored on the task (e.g. regress/classify).
//   version:        model version to record in the ModelSpec.
//   task:           task to populate; its model_spec is overwritten.
116 void PopulateTask(
const string& signature_name,
const string& method_name,
117 int64_t version, InferenceTask* task) {
118 ModelSpec model_spec;
119 model_spec.set_name(kTestModelName);
// NOTE(review): the embedded numbering skips the lines directly before and
// after this statement, so it may be guarded by a condition that is not
// visible here (callers do pass -1) — confirm before relying on it being
// unconditional.
121 model_spec.mutable_version()->set_value(version);
123 model_spec.set_signature_name(signature_name);
124 *task->mutable_model_spec() = model_spec;
125 task->set_method_name(method_name);
// Checks that `status` carries the expected error code and that its message
// contains `message_substring`.
//   status:            status under test.
//   expected_code:     code the status must have.
//   message_substring: text that must appear somewhere in status.message().
128 void ExpectStatusError(
const Status& status,
129 const absl::StatusCode expected_code,
130 const string& message_substring) {
// ASSERT_* is fatal for the current function: on a code mismatch this helper
// returns early and the message check below is not evaluated.
131 ASSERT_EQ(expected_code, status.code());
132 EXPECT_THAT(status.message(), ::testing::HasSubstr(message_substring));
// A request that has a task but no input (AddInput is never called) is
// expected to be rejected with kInvalidArgument and "Input is empty".
138 TEST_F(TfrtMultiInferenceTest, MissingInputTest) {
139 MultiInferenceRequest request;
140 PopulateTask(
"regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
142 MultiInferenceResponse response;
144 ServableHandle<Servable> servable;
145 TF_ASSERT_OK(GetServableHandle(&servable));
// NOTE(review): the call being checked is only partially visible. The
// remaining fragments show default tfrt::SavedModel::RunOptions, the fixture's
// servable_version_, and a down_cast of the servable to
// TfrtSavedModelServable, followed by the expected code and message —
// presumably arguments to an inference call wrapped in ExpectStatusError;
// confirm against the full file.
148 tfrt::SavedModel::RunOptions(), servable_version_,
149 &(down_cast<TfrtSavedModelServable*>(servable.get()))
152 absl::StatusCode::kInvalidArgument,
"Input is empty");
// A task naming a signature that the model does not define is expected to
// fail with kInvalidArgument and a message containing "not found".
155 TEST_F(TfrtMultiInferenceTest, UndefinedSignatureTest) {
156 MultiInferenceRequest request;
// Single example with feature "x" = 2.
157 AddInput({{
"x", 2}}, &request);
158 PopulateTask(
"ThisSignatureDoesNotExist", kRegressMethodName, -1,
159 request.add_tasks());
161 MultiInferenceResponse response;
162 ServableHandle<Servable> servable;
163 TF_ASSERT_OK(GetServableHandle(&servable));
164 ExpectStatusError(servable->MultiInference({}, request, &response),
165 absl::StatusCode::kInvalidArgument,
"not found");
// Two tasks in one request that name different models are expected to be
// rejected with kInvalidArgument ("must access the same model name").
169 TEST_F(TfrtMultiInferenceTest, InconsistentModelSpecsInRequestTest) {
170 MultiInferenceRequest request;
171 AddInput({{
"x", 2}}, &request);
// First task: targets kTestModelName via PopulateTask.
173 PopulateTask(
"regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
// Second task: built by hand so that it names a different model while using
// the same signature and method.
176 ModelSpec model_spec;
177 model_spec.set_name(
"ModelDoesNotExist");
178 model_spec.set_signature_name(
"regress_x_to_y");
179 auto* task = request.add_tasks();
180 *task->mutable_model_spec() = model_spec;
181 task->set_method_name(kRegressMethodName);
183 MultiInferenceResponse response;
184 ServableHandle<Servable> servable;
185 TF_ASSERT_OK(GetServableHandle(&servable));
186 ExpectStatusError(servable->MultiInference({}, request, &response),
187 absl::StatusCode::kInvalidArgument,
188 "must access the same model name");
// Two identical tasks (same signature and method) in one request are expected
// to be rejected with kInvalidArgument and a duplicate-evaluation message
// naming the signature.
191 TEST_F(TfrtMultiInferenceTest, EvaluateDuplicateFunctionsTest) {
192 MultiInferenceRequest request;
193 AddInput({{
"x", 2}}, &request);
194 PopulateTask(
"regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
// Deliberate repeat of the task above.
196 PopulateTask(
"regress_x_to_y", kRegressMethodName, -1, request.add_tasks());
198 MultiInferenceResponse response;
199 ServableHandle<Servable> servable;
200 TF_ASSERT_OK(GetServableHandle(&servable));
201 ExpectStatusError(servable->MultiInference({}, request, &response),
202 absl::StatusCode::kInvalidArgument,
203 "Duplicate evaluation of signature: regress_x_to_y");
// A task using the predict method name with the "serving_default" signature
// is expected to fail with kUnimplemented ("Unsupported signature").
// NOTE(review): the test name is spelled "Usupported" (sic); it is left
// untouched because renaming changes the registered test identifier.
206 TEST_F(TfrtMultiInferenceTest, UsupportedSignatureTypeTest) {
207 MultiInferenceRequest request;
208 AddInput({{
"x", 2}}, &request);
209 PopulateTask(
"serving_default", kPredictMethodName, -1, request.add_tasks());
211 MultiInferenceResponse response;
212 ServableHandle<Servable> servable;
213 TF_ASSERT_OK(GetServableHandle(&servable));
214 ExpectStatusError(servable->MultiInference({}, request, &response),
215 absl::StatusCode::kUnimplemented,
"Unsupported signature");
// Happy path with one regress task: the response must equal a hand-built
// proto containing a single regression result.
218 TEST_F(TfrtMultiInferenceTest, ValidSingleSignatureTest) {
219 MultiInferenceRequest request;
220 AddInput({{
"x", 2}}, &request);
221 PopulateTask(
"regress_x_to_y", kRegressMethodName, servable_version_,
222 request.add_tasks());
// Expected response: the task's ModelSpec echoed back with the version set to
// servable_version_, and one regression value of 3.0 for input x = 2.
224 MultiInferenceResponse expected_response;
225 auto* inference_result = expected_response.add_results();
226 auto* model_spec = inference_result->mutable_model_spec();
227 *model_spec = request.tasks(0).model_spec();
228 model_spec->mutable_version()->set_value(servable_version_);
229 auto* regression_result = inference_result->mutable_regression_result();
230 regression_result->add_regressions()->set_value(3.0);
232 MultiInferenceResponse response;
233 ServableHandle<Servable> servable;
234 TF_ASSERT_OK(GetServableHandle(&servable));
235 TF_ASSERT_OK(servable->MultiInference({}, request, &response));
236 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
// Happy path with two regress tasks on different signatures: the response
// must contain one result per task, in task order.
239 TEST_F(TfrtMultiInferenceTest, MultipleValidRegressSignaturesTest) {
240 MultiInferenceRequest request;
241 AddInput({{
"x", 2}}, &request);
242 PopulateTask(
"regress_x_to_y", kRegressMethodName, servable_version_,
243 request.add_tasks());
244 PopulateTask(
"regress_x_to_y2", kRegressMethodName, servable_version_,
245 request.add_tasks());
247 MultiInferenceResponse expected_response;
// Expected result for task 0 (regress_x_to_y): value 3.0.
250 auto* inference_result_1 = expected_response.add_results();
251 auto* model_spec_1 = inference_result_1->mutable_model_spec();
252 *model_spec_1 = request.tasks(0).model_spec();
253 model_spec_1->mutable_version()->set_value(servable_version_);
254 auto* regression_result_1 = inference_result_1->mutable_regression_result();
255 regression_result_1->add_regressions()->set_value(3.0);
// Expected result for task 1 (regress_x_to_y2): value 4.0.
258 auto* inference_result_2 = expected_response.add_results();
259 auto* model_spec_2 = inference_result_2->mutable_model_spec();
260 *model_spec_2 = request.tasks(1).model_spec();
261 model_spec_2->mutable_version()->set_value(servable_version_);
262 auto* regression_result_2 = inference_result_2->mutable_regression_result();
263 regression_result_2->add_regressions()->set_value(4.0);
265 MultiInferenceResponse response;
266 ServableHandle<Servable> servable;
267 TF_ASSERT_OK(GetServableHandle(&servable));
268 TF_ASSERT_OK(servable->MultiInference({}, request, &response));
269 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
// Happy path mixing method types: one regress task and one classify task in
// the same request, each producing its own kind of result.
272 TEST_F(TfrtMultiInferenceTest, RegressAndClassifySignaturesTest) {
273 MultiInferenceRequest request;
274 AddInput({{
"x", 2}}, &request);
275 PopulateTask(
"regress_x_to_y", kRegressMethodName, servable_version_,
276 request.add_tasks());
277 PopulateTask(
"classify_x_to_y", kClassifyMethodName, servable_version_,
278 request.add_tasks());
280 MultiInferenceResponse expected_response;
// Expected result for task 0: regression value 3.0.
281 auto* inference_result_1 = expected_response.add_results();
282 auto* model_spec_1 = inference_result_1->mutable_model_spec();
283 *model_spec_1 = request.tasks(0).model_spec();
284 model_spec_1->mutable_version()->set_value(servable_version_);
285 auto* regression_result = inference_result_1->mutable_regression_result();
286 regression_result->add_regressions()->set_value(3.0);
// Expected result for task 1: a single classification holding one class with
// score 3.0; no label is set on the expected class.
288 auto* inference_result_2 = expected_response.add_results();
289 auto* model_spec_2 = inference_result_2->mutable_model_spec();
290 *model_spec_2 = request.tasks(1).model_spec();
291 model_spec_2->mutable_version()->set_value(servable_version_);
292 auto* classification_result =
293 inference_result_2->mutable_classification_result();
294 classification_result->add_classifications()->add_classes()->set_score(3.0);
296 MultiInferenceResponse response;
297 ServableHandle<Servable> servable;
298 TF_ASSERT_OK(GetServableHandle(&servable));
299 TF_ASSERT_OK(servable->MultiInference({}, request, &response));
300 EXPECT_THAT(response, test_util::EqualsProto(expected_response));
static Status Create(Options options, std::unique_ptr< ServerCore > *core)