#include "tensorflow_serving/servables/tensorflow/predict_impl.h"

#include <memory>
#include <utility>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/model_servers/platform_config_util.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/servables/tensorflow/test_util/fake_thread_pool_factory.h"
#include "tensorflow_serving/servables/tensorflow/test_util/fake_thread_pool_factory.pb.h"
#include "tensorflow_serving/test_util/test_util.h"
#include "tensorflow_serving/util/oss_or_google.h"
namespace tensorflow {
namespace serving {
namespace {
constexpr char kTestModelName[] = "test_model";
constexpr int kTestModelVersion = 123;

const char kInputTensorKey[] = "x";
const char kOutputTensorKey[] = "y";

class PredictImplTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() {
    TF_ASSERT_OK(CreateServerCore(test_util::TensorflowTestSrcDirPath(
                                      "cc/saved_model/testdata/half_plus_two"),
                                  &saved_model_server_core_));
    TF_ASSERT_OK(CreateServerCore(
        test_util::TestSrcDirPath(
            "/servables/tensorflow/testdata/saved_model_counter"),
        &saved_model_server_core_counter_model_));
  }

  static void TearDownTestSuite() {
    server_core_.reset();
    server_core_bad_model_.reset();
    saved_model_server_core_.reset();
    saved_model_server_core_counter_model_.reset();
  }

 protected:
  static Status CreateServerCore(const string& model_path,
                                 std::unique_ptr<ServerCore>* server_core) {
    ModelServerConfig config;
    auto model_config = config.mutable_model_config_list()->add_config();
    model_config->set_name(kTestModelName);
    model_config->set_base_path(model_path);
    model_config->set_model_platform(kTensorFlowModelPlatform);

    // Leave servable_state_monitor_creator unspecified so the default one is
    // used.
    ServerCore::Options options;
    options.model_server_config = config;
    options.platform_config_map =
        CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
    options.aspired_version_policy =
        std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
    // Reduce the number of initial load threads to avoid timing out in tests.
    options.num_initial_load_threads = options.num_load_threads;
    return ServerCore::Create(std::move(options), server_core);
  }

  ServerCore* GetServerCore() { return saved_model_server_core_.get(); }

  ServerCore* GetServerCoreWithCounterModel() {
    return saved_model_server_core_counter_model_.get();
  }

  RunOptions GetRunOptions() { return RunOptions(); }

 private:
  static std::unique_ptr<ServerCore> server_core_;
  static std::unique_ptr<ServerCore> server_core_bad_model_;
  static std::unique_ptr<ServerCore> saved_model_server_core_;
  static std::unique_ptr<ServerCore> saved_model_server_core_counter_model_;
};

std::unique_ptr<ServerCore> PredictImplTest::server_core_;
std::unique_ptr<ServerCore> PredictImplTest::server_core_bad_model_;
std::unique_ptr<ServerCore> PredictImplTest::saved_model_server_core_;
std::unique_ptr<ServerCore>
    PredictImplTest::saved_model_server_core_counter_model_;

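// A request with no model spec, an empty model name, or an unknown model name
// should be rejected rather than dispatched to a model.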
TEST_F(PredictImplTest, MissingOrEmptyModelSpec) {
  PredictRequest request;
  PredictResponse response;

  TensorflowPredictor predictor;
  // No model spec at all.
  EXPECT_EQ(
      static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
          .code());

  // Model name is not specified.
  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->clear_name();
  EXPECT_EQ(
      static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
          .code());

  // Model name is wrong, not found.
  model_spec->set_name("test");
  EXPECT_EQ(
      tensorflow::error::NOT_FOUND,
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
          .code());
}

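// A request that names a valid model but supplies no input tensors is invalid.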
TEST_F(PredictImplTest, EmptyInputList) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorflowPredictor predictor;
  // The input is empty.
  EXPECT_EQ(
      static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
          .code());
}

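// Input keys that do not match the model signature's inputs are rejected.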
TEST_F(PredictImplTest, InputTensorsDontMatchModelSpecInputs) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_string_val("any_key");
  tensor_proto.set_dtype(tensorflow::DT_STRING);
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);

  TensorflowPredictor predictor;
  auto inputs = request.mutable_inputs();
  (*inputs)["key"] = tensor_proto;
  EXPECT_EQ(
      static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
          .code());
}

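// Output filters must reference outputs that exist in the signature, and each
// output may be requested at most once.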
TEST_F(PredictImplTest, OutputFiltersDontMatchModelSpecOutputs) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
  request.add_output_filter("output_filter");

  TensorflowPredictor predictor;
  // An output filter like this doesn't exist in the signature.
  EXPECT_EQ(
      static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
          .code());

  request.clear_output_filter();
  request.add_output_filter(kOutputTensorKey);
  TF_EXPECT_OK(
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
  request.add_output_filter(kOutputTensorKey);

  // Duplicate output filters are not allowed.
  EXPECT_EQ(
      static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
          .code());
}

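// Feeding a string tensor to a float input produces an invalid-argument error.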
TEST_F(PredictImplTest, InputTensorsHaveWrongType) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_string_val("any_key");
  tensor_proto.set_dtype(tensorflow::DT_STRING);
  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
  request.add_output_filter(kOutputTensorKey);

  TensorflowPredictor predictor;
  // The input tensor has the wrong type for the signature.
  EXPECT_EQ(
      static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response)
          .code());
}

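// With the half_plus_two model, feeding x = 2.0 through the default serving
// signature should return y = 2.0 * 0.5 + 2 = 3.0.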
TEST_F(PredictImplTest, PredictionSuccess) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;

  TensorflowPredictor predictor;
  TF_EXPECT_OK(
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
  TensorProto output_tensor_proto;
  output_tensor_proto.add_float_val(3);
  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto.mutable_tensor_shape();
  PredictResponse expected_response;
  *expected_response.mutable_model_spec() = *model_spec;
  expected_response.mutable_model_spec()->set_signature_name(
      kDefaultServingSignatureDefKey);
  (*expected_response.mutable_outputs())[kOutputTensorKey] =
      output_tensor_proto;
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}

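// Querying the model through a named regression signature (rather than the
// default serving signature) also works for SavedModel.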
TEST_F(PredictImplTest, PredictionWithNamedRegressionSignature) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  model_spec->set_signature_name("regress_x2_to_y3");

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kRegressInputs] = tensor_proto;
  TensorflowPredictor predictor;
  TF_ASSERT_OK(
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
  TensorProto output_tensor_proto;
  output_tensor_proto.add_float_val(4);
  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto.mutable_tensor_shape();
  PredictResponse expected_response;
  *expected_response.mutable_model_spec() = *model_spec;
  (*expected_response.mutable_outputs())[kRegressOutputs] = output_tensor_proto;
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}

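// The same works for a named classification signature; only the scores output
// is checked here.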
TEST_F(PredictImplTest, PredictionWithNamedClassificationSignature) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  model_spec->set_signature_name("classify_x2_to_y3");

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kClassifyInputs] = tensor_proto;

  TensorflowPredictor predictor;
  TF_ASSERT_OK(
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
  TensorProto output_tensor_proto;
  output_tensor_proto.add_float_val(4);
  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto.mutable_tensor_shape();
  PredictResponse expected_response;
  *expected_response.mutable_model_spec() = *model_spec;
  (*expected_response.mutable_outputs())[kClassifyOutputScores] =
      output_tensor_proto;
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}

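// Exercises custom signatures on the counter model: get_counter, incr_counter,
// reset_counter, and incr_counter_by, checking the counter value after each
// call.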
TEST_F(PredictImplTest, PredictionWithCustomizedSignatures) {
  PredictRequest request;
  PredictResponse response;
  TensorflowPredictor predictor;

  // Call get_counter. Expected result: 0.
  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  model_spec->set_signature_name("get_counter");

  TF_ASSERT_OK(predictor.Predict(
      GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));

  PredictResponse expected_get_counter;
  *expected_get_counter.mutable_model_spec() = *model_spec;
  TensorProto output_get_counter;
  output_get_counter.add_float_val(0);
  output_get_counter.set_dtype(tensorflow::DT_FLOAT);
  output_get_counter.mutable_tensor_shape();
  (*expected_get_counter.mutable_outputs())["output"] = output_get_counter;
  EXPECT_THAT(response, test_util::EqualsProto(expected_get_counter));

  // Call incr_counter. Expected result: 1.
  model_spec->set_signature_name("incr_counter");
  TF_ASSERT_OK(predictor.Predict(
      GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));

  PredictResponse expected_incr_counter;
  *expected_incr_counter.mutable_model_spec() = *model_spec;
  TensorProto output_incr_counter;
  output_incr_counter.add_float_val(1);
  output_incr_counter.set_dtype(tensorflow::DT_FLOAT);
  output_incr_counter.mutable_tensor_shape();
  (*expected_incr_counter.mutable_outputs())["output"] = output_incr_counter;
  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter));

  // Call reset_counter. Expected result: 0.
  model_spec->set_signature_name("reset_counter");
  TF_ASSERT_OK(predictor.Predict(
      GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));

  PredictResponse expected_reset_counter;
  *expected_reset_counter.mutable_model_spec() = *model_spec;
  TensorProto output_reset_counter;
  output_reset_counter.add_float_val(0);
  output_reset_counter.set_dtype(tensorflow::DT_FLOAT);
  output_reset_counter.mutable_tensor_shape();
  (*expected_reset_counter.mutable_outputs())["output"] = output_reset_counter;
  EXPECT_THAT(response, test_util::EqualsProto(expected_reset_counter));

  // Call incr_counter again, this time with an output filter.
  // Expected result: 1.
  model_spec->set_signature_name("incr_counter");
  request.add_output_filter("output");
  TF_ASSERT_OK(predictor.Predict(
      GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));
  request.clear_output_filter();

  PredictResponse expected_incr_counter2;
  *expected_incr_counter2.mutable_model_spec() = *model_spec;
  TensorProto output_incr_counter2;
  output_incr_counter2.add_float_val(1);
  output_incr_counter2.set_dtype(tensorflow::DT_FLOAT);
  output_incr_counter2.mutable_tensor_shape();
  (*expected_incr_counter2.mutable_outputs())["output"] = output_incr_counter2;
  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter2));

  // Call incr_counter_by with delta = 3. Expected result: 1 + 3 = 4.
  model_spec->set_signature_name("incr_counter_by");
  TensorProto tensor_proto;
  tensor_proto.add_float_val(3);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())["delta"] = tensor_proto;

  TF_ASSERT_OK(predictor.Predict(
      GetRunOptions(), GetServerCoreWithCounterModel(), request, &response));

  PredictResponse expected_incr_counter_by;
  *expected_incr_counter_by.mutable_model_spec() = *model_spec;
  TensorProto output_incr_counter_by;
  output_incr_counter_by.add_float_val(4);
  output_incr_counter_by.set_dtype(tensorflow::DT_FLOAT);
  output_incr_counter_by.mutable_tensor_shape();
  (*expected_incr_counter_by.mutable_outputs())["output"] =
      output_incr_counter_by;
  EXPECT_THAT(response, test_util::EqualsProto(expected_incr_counter_by));
}

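// PredictWithModelSpec() should use the explicitly passed model spec and
// ignore the one embedded in the request.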
TEST_F(PredictImplTest, ModelSpecOverride) {
  auto request = test_util::CreateProto<PredictRequest>(
      "model_spec {"
      "  name: \"test_model\""
      "}");
  auto model_spec_override =
      test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");

  TensorflowPredictor predictor;
  PredictResponse response;
  EXPECT_NE(tensorflow::error::NOT_FOUND,
            predictor.Predict(RunOptions(), GetServerCore(), request, &response)
                .code());
  EXPECT_EQ(tensorflow::error::NOT_FOUND,
            predictor
                .PredictWithModelSpec(RunOptions(), GetServerCore(),
                                      model_spec_override, request, &response)
                .code());
}

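// When a ThreadPoolFactory is supplied, Predict should run the session using
// the factory's inter-op thread pool.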
TEST_F(PredictImplTest, ThreadPoolFactory) {
  PredictRequest request;
  PredictResponse response;

  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);

  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;

  auto inter_op_threadpool = std::make_shared<test_util::CountingThreadPool>(
      Env::Default(), "InterOp", /*num_threads=*/1);
  auto intra_op_threadpool = std::make_shared<test_util::CountingThreadPool>(
      Env::Default(), "IntraOp", /*num_threads=*/1);
  test_util::FakeThreadPoolFactoryConfig fake_thread_pool_factory_config;
  test_util::FakeThreadPoolFactory fake_thread_pool_factory(
      fake_thread_pool_factory_config);
  fake_thread_pool_factory.SetInterOpThreadPool(inter_op_threadpool);
  fake_thread_pool_factory.SetIntraOpThreadPool(intra_op_threadpool);

  TensorflowPredictor predictor(&fake_thread_pool_factory);
  TF_EXPECT_OK(
      predictor.Predict(GetRunOptions(), GetServerCore(), request, &response));
  TensorProto output_tensor_proto;
  output_tensor_proto.add_float_val(3);
  output_tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  output_tensor_proto.mutable_tensor_shape();
  PredictResponse expected_response;
  *expected_response.mutable_model_spec() = *model_spec;
  expected_response.mutable_model_spec()->set_signature_name(
      kDefaultServingSignatureDefKey);
  (*expected_response.mutable_outputs())[kOutputTensorKey] =
      output_tensor_proto;
  EXPECT_THAT(response, test_util::EqualsProto(expected_response));

  // At least one op should have been scheduled on the inter-op thread pool.
  ASSERT_GE(inter_op_threadpool->NumScheduled(), 1);
}

}  // namespace
}  // namespace serving
}  // namespace tensorflow