16 #include "tensorflow_serving/servables/tensorflow/regressor.h"
23 #include "google/protobuf/map.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/cc/saved_model/signature_constants.h"
26 #include "tensorflow/core/example/example.pb.h"
27 #include "tensorflow/core/example/feature.pb.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/threadpool_options.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/public/session.h"
35 #include "tensorflow_serving/apis/input.pb.h"
36 #include "tensorflow_serving/apis/model.pb.h"
37 #include "tensorflow_serving/apis/regression.pb.h"
38 #include "tensorflow_serving/core/test_util/mock_session.h"
39 #include "tensorflow_serving/servables/tensorflow/util.h"
40 #include "tensorflow_serving/test_util/test_util.h"
42 namespace tensorflow {
// Pull the proto matcher and the mock session out of test_util for brevity.
46 using test_util::EqualsProto;
47 using test_util::MockSession;
// Tensor names: the single input name FakeSession::Run accepts, and the output
// names FakeSession::GetOutputTensor distinguishes between.
50 const char kInputTensor[] =
"input:0";
51 const char kOutputTensor[] =
"output:0";
52 const char kOutputPlusOneTensor[] =
"outputPlusOne:0";
53 const char kImproperlySizedOutputTensor[] =
"ImproperlySizedOutput:0";
// Name of the Example feature FakeSession reads each output value from.
54 const char kOutputFeature[] =
"output";
// Names of the extra signatures the test fixture registers in SetUp().
56 const char kOutputPlusOneSignature[] =
"output_plus_one";
57 const char kInvalidNamedSignature[] =
"invalid_classification_signature";
58 const char kImproperlySizedOutputSignature[] =
"ImproperlySizedOutputSignature";
// Session test double. Create/Extend/Close/ListDevices return Unimplemented;
// the Run overloads build an output tensor from the serialized Examples found
// in the input tensor.
63 class FakeSession :
public tensorflow::Session {
// `expected_timeout`, when set, is the value Run() CHECKs against
// RunOptions::timeout_in_ms().
65 explicit FakeSession(absl::optional<int64_t> expected_timeout)
66 : expected_timeout_(expected_timeout) {}
67 ~FakeSession()
override =
default;
// Not supported by the fake: returns an Unimplemented error.
68 Status Create(
const GraphDef& graph)
override {
69 return errors::Unimplemented(
"not available in fake");
// Not supported by the fake: returns an Unimplemented error.
71 Status Extend(
const GraphDef& graph)
override {
72 return errors::Unimplemented(
"not available in fake");
// Not supported by the fake: returns an Unimplemented error.
75 Status Close()
override {
76 return errors::Unimplemented(
"not available in fake");
// Not supported by the fake: returns an Unimplemented error and leaves
// `response` untouched.
79 Status ListDevices(std::vector<DeviceAttributes>* response)
override {
80 return errors::Unimplemented(
"not available in fake");
// Run() overload without RunOptions. When an expected timeout was configured
// this overload must not be used, so it logs FATAL. The visible return
// forwards to the RunOptions overload with default-constructed RunOptions.
// NOTE(review): the tail of the forwarding call is not visible in this excerpt.
83 Status Run(
const std::vector<std::pair<string, Tensor>>& inputs,
84 const std::vector<string>& output_names,
85 const std::vector<string>& target_nodes,
86 std::vector<Tensor>* outputs)
override {
87 if (expected_timeout_) {
88 LOG(FATAL) <<
"Run() without RunOptions not expected to be called";
90 RunMetadata run_metadata;
91 return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
// Run() overload with RunOptions but no thread pool options. Forwards to the
// overload below, passing default-constructed thread::ThreadPoolOptions.
95 Status Run(
const RunOptions& run_options,
96 const std::vector<std::pair<string, Tensor>>& inputs,
97 const std::vector<string>& output_names,
98 const std::vector<string>& target_nodes,
99 std::vector<Tensor>* outputs, RunMetadata* run_metadata)
override {
100 return Run(run_options, inputs, output_names, target_nodes, outputs,
101 run_metadata, thread::ThreadPoolOptions());
// The Run() overload that does the actual work:
//   1. If an expected timeout is configured, CHECKs it equals
//      run_options.timeout_in_ms().
//   2. Requires exactly one input, named kInputTensor; otherwise returns an
//      Internal error.
//   3. Parses the input tensor into Examples, builds the tensor for
//      output_names[0], and appends it to *outputs.
// NOTE(review): the declaration of `output` is not visible in this excerpt.
104 Status Run(
const RunOptions& run_options,
105 const std::vector<std::pair<string, Tensor>>& inputs,
106 const std::vector<string>& output_names,
107 const std::vector<string>& target_nodes,
108 std::vector<Tensor>* outputs, RunMetadata* run_metadata,
109 const thread::ThreadPoolOptions& thread_pool_options)
override {
110 if (expected_timeout_) {
111 CHECK_EQ(*expected_timeout_, run_options.timeout_in_ms());
113 if (inputs.size() != 1 || inputs[0].first != kInputTensor) {
114 return errors::Internal(
"Expected one input Tensor.");
116 const Tensor& input = inputs[0].second;
117 std::vector<Example> examples;
118 TF_RETURN_IF_ERROR(GetExamples(input, &examples));
120 TF_RETURN_IF_ERROR(GetOutputTensor(examples, output_names[0], &output));
121 outputs->push_back(output);
122 return absl::OkStatus();
// Parses the string tensor `input` into Example protos appended to *examples.
// The number of entries read is input.dim_size(0); each entry is parsed with
// ParseFromArray, and a parse failure returns an Internal error.
// NOTE(review): the declaration of `example` is not visible in this excerpt.
126 static Status GetExamples(
const Tensor& input,
127 std::vector<Example>* examples) {
129 const int batch_size = input.dim_size(0);
130 const auto& flat_input = input.flat<tstring>();
131 for (
int i = 0; i < batch_size; ++i) {
133 if (!example.ParseFromArray(flat_input(i).data(), flat_input(i).size())) {
134 return errors::Internal(
"failed to parse example");
136 examples->push_back(example);
138 return absl::OkStatus();
// Looks up the feature called `name` in `example`'s feature map.
// NOTE(review): the return statements (found / not-found cases) are not
// visible in this excerpt; presumably the found feature is returned and a
// default Feature otherwise -- confirm against the full file.
143 static Feature GetFeature(
const Example& example,
const string& name) {
144 const auto it = example.features().feature().find(name);
145 if (it != example.features().feature().end()) {
// Builds the DT_FLOAT output tensor for `output_tensor_name` from `examples`.
//   - Empty `examples` returns an Internal error.
//   - kImproperlySizedOutputTensor: a {batch_size, 1, 10} tensor is assigned
//     and returned immediately, without writing any values.
//   - kOutputPlusOneTensor: shape {batch_size, 1}, each value is the example's
//     kOutputFeature float plus 1.
//   - any other name: shape {batch_size}, each value is the feature float.
// An example whose kOutputFeature does not hold exactly one float value
// yields an Internal error.
// NOTE(review): the output parameter declaration (`tensor`) is not visible in
// this excerpt.
154 static Status GetOutputTensor(
const std::vector<Example>& examples,
155 const string& output_tensor_name,
157 if (examples.empty()) {
158 return errors::Internal(
"empty example list");
160 const int batch_size = examples.size();
161 if (output_tensor_name == kImproperlySizedOutputTensor) {
164 *tensor = Tensor(DT_FLOAT, TensorShape({batch_size, 1, 10}));
165 return absl::OkStatus();
169 *tensor = output_tensor_name == kOutputPlusOneTensor
170 ? Tensor(DT_FLOAT, TensorShape({batch_size, 1}))
171 : Tensor(DT_FLOAT, TensorShape({batch_size}));
173 const float offset = output_tensor_name == kOutputPlusOneTensor ? 1 : 0;
174 for (
int i = 0; i < batch_size; ++i) {
175 const Feature feature = GetFeature(examples[i], kOutputFeature);
176 if (feature.float_list().value_size() != 1) {
177 return errors::Internal(
"incorrect number of values in output feature");
179 tensor->flat<
float>()(i) = feature.float_list().value(0) + offset;
181 return absl::OkStatus();
// Timeout that Run() expects in RunOptions; unset disables the check.
185 const absl::optional<int64_t> expected_timeout_;
// Fixture for the regressor tests, parameterized on a bool that says whether
// the signature method-name check is enabled (see IsMethodNameCheckEnabled()).
188 class RegressorTest :
public ::testing::TestWithParam<bool> {
// Prepares a SavedModelBundle backed by a FakeSession:
//   - applies the method-name-check setting from the test parameter;
//   - creates the FakeSession with the timeout from GetRunOptions() as its
//     expected timeout and hands it to the bundle's session pointer;
//   - registers the default serving signature (kInputTensor -> kOutputTensor
//     under the regress input/output keys), setting the regress method name
//     only when the check is enabled;
//   - registers the named signatures kOutputPlusOneSignature (regression),
//     kInvalidNamedSignature (non-regression) and
//     kImproperlySizedOutputSignature (regression).
// NOTE(review): the tail of the last AddNamedSignatureToSavedModelBundle call
// is not visible in this excerpt.
190 void SetUp()
override {
191 SetSignatureMethodNameCheckFeature(IsMethodNameCheckEnabled());
192 saved_model_bundle_.reset(
new SavedModelBundle);
193 meta_graph_def_ = &saved_model_bundle_->meta_graph_def;
194 absl::optional<int64_t> expected_timeout = GetRunOptions().timeout_in_ms();
195 fake_session_ =
new FakeSession(expected_timeout);
196 saved_model_bundle_->session.reset(fake_session_);
198 auto* signature_defs = meta_graph_def_->mutable_signature_def();
199 SignatureDef sig_def;
200 TensorInfo input_tensor_info;
201 input_tensor_info.set_name(kInputTensor);
202 (*sig_def.mutable_inputs())[kRegressInputs] = input_tensor_info;
203 TensorInfo scores_tensor_info;
204 scores_tensor_info.set_name(kOutputTensor);
205 (*sig_def.mutable_outputs())[kRegressOutputs] = scores_tensor_info;
206 if (IsMethodNameCheckEnabled()) sig_def.set_method_name(kRegressMethodName);
207 (*signature_defs)[kDefaultServingSignatureDefKey] = sig_def;
209 AddNamedSignatureToSavedModelBundle(
210 kInputTensor, kOutputPlusOneTensor, kOutputPlusOneSignature,
211 true , meta_graph_def_);
212 AddNamedSignatureToSavedModelBundle(
213 kInputTensor, kOutputPlusOneTensor, kInvalidNamedSignature,
214 false , meta_graph_def_);
217 AddNamedSignatureToSavedModelBundle(
218 kInputTensor, kImproperlySizedOutputTensor,
219 kImproperlySizedOutputSignature,
true ,
// Returns the bool test parameter, which selects whether the signature
// method-name check is enabled for this run.
224 bool IsMethodNameCheckEnabled() {
return GetParam(); }
// Builds an Example whose "output" feature is a float list holding the single
// value `output`.
// NOTE(review): the local declarations and the return statement are not
// visible in this excerpt.
227 Example example_with_output(
const float output) {
229 feature.mutable_float_list()->add_value(output);
231 (*example.mutable_features()->mutable_feature())[
"output"] = feature;
// Body of the fixture helper that builds regressor_ (the tests invoke it as
// Create(); NOTE(review): the function header is not visible in this excerpt).
// Copies the fixture's MetaGraphDef into a fresh SavedModelBundle, moves the
// session out of saved_model_bundle_ into it, and returns the status of
// CreateRegressorFromSavedModelBundle.
236 std::unique_ptr<SavedModelBundle> saved_model(
new SavedModelBundle);
237 saved_model->meta_graph_def = saved_model_bundle_->meta_graph_def;
238 saved_model->session = std::move(saved_model_bundle_->session);
// regressor_ is passed by address so the factory can fill it in.
239 return CreateRegressorFromSavedModelBundle(
240 GetRunOptions(), std::move(saved_model), &regressor_);
// Builds the RunOptions used throughout the tests, with timeout_in_ms set to
// 42. SetUp() passes the same timeout to FakeSession as its expected timeout.
// NOTE(review): the return statement is not visible in this excerpt.
243 RunOptions GetRunOptions()
const {
244 RunOptions run_options;
245 run_options.set_timeout_in_ms(42);
// Adds a SignatureDef named `signature_name` to `meta_graph_def`.
// Two variants are built below:
//   - regression: kRegressInputs -> input_tensor_name and kRegressOutputs ->
//     output_scores_tensor_name, with method name kRegressMethodName;
//   - classification: kClassifyInputs -> input_tensor_name and
//     kClassifyOutputClasses -> kOutputPlusOneTensor (the
//     output_scores_tensor_name argument is not used here), with method name
//     kClassifyMethodName.
// The method name is only written when the method-name check is enabled.
// NOTE(review): the branch condition and the declaration of `method_name` are
// not visible in this excerpt; presumably `is_regression` selects the variant.
252 void AddNamedSignatureToSavedModelBundle(
253 const string& input_tensor_name,
const string& output_scores_tensor_name,
254 const string& signature_name,
const bool is_regression,
255 tensorflow::MetaGraphDef* meta_graph_def) {
256 auto* signature_defs = meta_graph_def->mutable_signature_def();
257 SignatureDef sig_def;
260 TensorInfo input_tensor_info;
261 input_tensor_info.set_name(input_tensor_name);
262 (*sig_def.mutable_inputs())[kRegressInputs] = input_tensor_info;
263 TensorInfo scores_tensor_info;
264 scores_tensor_info.set_name(output_scores_tensor_name);
265 (*sig_def.mutable_outputs())[kRegressOutputs] = scores_tensor_info;
266 method_name = kRegressMethodName;
268 TensorInfo input_tensor_info;
269 input_tensor_info.set_name(input_tensor_name);
270 (*sig_def.mutable_inputs())[kClassifyInputs] = input_tensor_info;
271 TensorInfo class_tensor_info;
272 class_tensor_info.set_name(kOutputPlusOneTensor);
273 (*sig_def.mutable_outputs())[kClassifyOutputClasses] = class_tensor_info;
274 method_name = kClassifyMethodName;
276 if (IsMethodNameCheckEnabled()) sig_def.set_method_name(method_name);
277 (*signature_defs)[signature_name] = sig_def;
// Points at saved_model_bundle_->meta_graph_def (set in SetUp()).
281 tensorflow::MetaGraphDef* meta_graph_def_;
// Raw pointer to the fake session created in SetUp(); the session object is
// held by a bundle's `session` smart pointer, not by this member.
282 FakeSession* fake_session_;
// Bundle assembled in SetUp(); its session is moved out when the regressor
// is created.
283 std::unique_ptr<SavedModelBundle> saved_model_bundle_;
// Regressor under test, filled in by CreateRegressorFromSavedModelBundle.
286 std::unique_ptr<RegressorInterface> regressor_;
// Request/result pair shared by the test bodies.
289 RegressionRequest request_;
290 RegressionResult result_;
// Two examples (outputs 2.0 and 3.0) sent through the default signature must
// succeed both via regressor_->Regress and via RunRegress, and each result is
// compared against an expected proto.
// NOTE(review): the `examples` declaration and the expected proto text are
// not fully visible in this excerpt.
293 TEST_P(RegressorTest, BasicExampleList) {
294 TF_ASSERT_OK(Create());
296 request_.mutable_input()->mutable_example_list()->mutable_examples();
297 *examples->Add() = example_with_output(2.0);
298 *examples->Add() = example_with_output(3.0);
299 TF_ASSERT_OK(regressor_->Regress(request_, &result_));
300 EXPECT_THAT(result_, EqualsProto(
" regressions { "
306 RegressionResponse response;
307 TF_ASSERT_OK(RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def,
308 {}, fake_session_, request_, &response));
309 EXPECT_THAT(response.result(), EqualsProto(
" regressions { "
// Uses an ExampleListWithContext input: two empty examples plus a context
// example carrying output 3.0. Both Regress and RunRegress must succeed and
// their results are compared against an expected proto.
// NOTE(review): the expected proto text is not fully visible in this excerpt.
317 TEST_P(RegressorTest, BasicExampleListWithContext) {
318 TF_ASSERT_OK(Create());
319 auto* list_with_context =
320 request_.mutable_input()->mutable_example_list_with_context();
322 list_with_context->add_examples();
323 list_with_context->add_examples();
325 *list_with_context->mutable_context() = example_with_output(3.0);
326 TF_ASSERT_OK(regressor_->Regress(request_, &result_));
327 EXPECT_THAT(result_, EqualsProto(
" regressions { "
333 RegressionResponse response;
334 TF_ASSERT_OK(RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def,
335 {}, fake_session_, request_, &response));
336 EXPECT_THAT(response.result(), EqualsProto(
" regressions { "
// Selects the kOutputPlusOneSignature named signature in the model spec and
// sends two examples (2.0 and 3.0); both Regress and RunRegress must succeed
// and their results are compared against an expected proto.
// NOTE(review): the `examples` declaration and the expected proto text are
// not fully visible in this excerpt.
344 TEST_P(RegressorTest, ValidNamedSignature) {
345 TF_ASSERT_OK(Create());
346 request_.mutable_model_spec()->set_signature_name(kOutputPlusOneSignature);
348 request_.mutable_input()->mutable_example_list()->mutable_examples();
349 *examples->Add() = example_with_output(2.0);
350 *examples->Add() = example_with_output(3.0);
351 TF_ASSERT_OK(regressor_->Regress(request_, &result_));
352 EXPECT_THAT(result_, EqualsProto(
" regressions { "
359 RegressionResponse response;
360 TF_ASSERT_OK(RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def,
361 {}, fake_session_, request_, &response));
362 EXPECT_THAT(response.result(), EqualsProto(
" regressions { "
// Selects kInvalidNamedSignature (registered in SetUp() as a non-regression
// signature). Both Regress and RunRegress must fail; the status is checked
// against absl::StatusCode::kInvalidArgument.
// NOTE(review): the `examples` declaration and the second operand of each
// EXPECT_EQ are not visible in this excerpt.
370 TEST_P(RegressorTest, InvalidNamedSignature) {
371 TF_ASSERT_OK(Create());
372 request_.mutable_model_spec()->set_signature_name(kInvalidNamedSignature);
374 request_.mutable_input()->mutable_example_list()->mutable_examples();
375 *examples->Add() = example_with_output(2.0);
376 *examples->Add() = example_with_output(3.0);
377 Status status = regressor_->Regress(request_, &result_);
378 ASSERT_FALSE(status.ok());
379 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
383 RegressionResponse response;
384 status = RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
385 fake_session_, request_, &response);
386 ASSERT_FALSE(status.ok());
387 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
// Selects kImproperlySizedOutputSignature, for which the fake session returns
// a {batch_size, 1, 10} tensor. Both Regress and RunRegress must fail; the
// status is checked against absl::StatusCode::kInvalidArgument.
// NOTE(review): the `examples` declaration and the second operand of each
// EXPECT_EQ are not visible in this excerpt.
392 TEST_P(RegressorTest, MalformedOutputs) {
393 TF_ASSERT_OK(Create());
394 request_.mutable_model_spec()->set_signature_name(
395 kImproperlySizedOutputSignature);
397 request_.mutable_input()->mutable_example_list()->mutable_examples();
398 *examples->Add() = example_with_output(2.0);
399 *examples->Add() = example_with_output(3.0);
400 Status status = regressor_->Regress(request_, &result_);
402 ASSERT_FALSE(status.ok());
403 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
407 RegressionResponse response;
408 status = RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
409 fake_session_, request_, &response);
410 ASSERT_FALSE(status.ok());
411 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
// An Input message with no example list set must be rejected by both Regress
// and RunRegress with INVALID_ARGUMENT and a message containing
// "Input is empty".
416 TEST_P(RegressorTest, EmptyInput) {
417 TF_ASSERT_OK(Create());
// Touch the input field so it exists but stays empty.
419 request_.mutable_input();
420 Status status = regressor_->Regress(request_, &result_);
421 ASSERT_FALSE(status.ok());
422 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
423 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
424 RegressionResponse response;
425 status = RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
426 fake_session_, request_, &response);
427 ASSERT_FALSE(status.ok());
428 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
429 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
// An example list that exists but holds no examples must be rejected by both
// Regress and RunRegress with INVALID_ARGUMENT and a message containing
// "Input is empty".
432 TEST_P(RegressorTest, EmptyExampleList) {
433 TF_ASSERT_OK(Create());
434 request_.mutable_input()->mutable_example_list();
435 Status status = regressor_->Regress(request_, &result_);
436 ASSERT_FALSE(status.ok());
437 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
438 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
439 RegressionResponse response;
440 status = RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
441 fake_session_, request_, &response);
442 ASSERT_FALSE(status.ok());
443 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
444 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
// An ExampleListWithContext that has a context but no examples must be
// rejected by both Regress and RunRegress with INVALID_ARGUMENT and a message
// containing "Input is empty".
447 TEST_P(RegressorTest, EmptyExampleListWithContext) {
448 TF_ASSERT_OK(Create());
// Only the context is populated; no examples are added.
450 *request_.mutable_input()
451 ->mutable_example_list_with_context()
452 ->mutable_context() = example_with_output(3);
453 Status status = regressor_->Regress(request_, &result_);
454 ASSERT_FALSE(status.ok());
455 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
456 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
457 RegressionResponse response;
458 status = RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
459 fake_session_, request_, &response);
460 ASSERT_FALSE(status.ok());
461 EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
462 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Input is empty"));
// Replaces the fake session with a MockSession whose Run returns an Internal
// error ("Run totally failed"). Both Regress and RunRegress must fail and
// surface that message.
// NOTE(review): the gMock action clause wrapping the Return (between the
// EXPECT_CALL and ::testing::Return lines) is not visible in this excerpt.
465 TEST_P(RegressorTest, RunsFails) {
466 MockSession* mock =
new MockSession;
467 saved_model_bundle_->session.reset(mock);
468 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
470 ::testing::Return(errors::Internal(
"Run totally failed")));
471 TF_ASSERT_OK(Create());
472 *request_.mutable_input()->mutable_example_list()->mutable_examples()->Add() =
473 example_with_output(2.0);
474 Status status = regressor_->Regress(request_, &result_);
475 ASSERT_FALSE(status.ok());
476 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"Run totally failed"));
477 RegressionResponse response;
478 status = RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
479 mock, request_, &response);
480 ASSERT_FALSE(status.ok());
481 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"Run totally failed"));
// The mock session returns a DT_FLOAT tensor of shape {2} for a request that
// contains a single example. Both Regress and RunRegress must fail with a
// message containing "output batch size".
484 TEST_P(RegressorTest, UnexpectedOutputTensorSize) {
485 MockSession* mock =
new MockSession;
486 saved_model_bundle_->session.reset(mock);
487 std::vector<Tensor> outputs = {Tensor(DT_FLOAT, TensorShape({2}))};
488 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
489 .WillOnce(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
490 ::testing::Return(absl::OkStatus())));
491 TF_ASSERT_OK(Create());
492 *request_.mutable_input()->mutable_example_list()->mutable_examples()->Add() =
493 example_with_output(2.0);
494 Status status = regressor_->Regress(request_, &result_);
495 ASSERT_FALSE(status.ok());
496 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"output batch size"));
// The first expectation was WillOnce, so set a fresh one for RunRegress.
497 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
498 .WillOnce(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
499 ::testing::Return(absl::OkStatus())));
500 RegressionResponse response;
501 status = RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
502 mock, request_, &response);
503 ASSERT_FALSE(status.ok());
504 EXPECT_THAT(status.ToString(), ::testing::HasSubstr(
"output batch size"));
// The mock session returns a DT_STRING tensor of shape {1} instead of a float
// tensor. Both Regress and RunRegress must fail with a message containing
// "Expected output Tensor of DT_FLOAT".
507 TEST_P(RegressorTest, UnexpectedOutputTensorType) {
508 MockSession* mock =
new MockSession;
509 saved_model_bundle_->session.reset(mock);
511 std::vector<Tensor> outputs = {Tensor(DT_STRING, TensorShape({1}))};
512 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
513 .WillOnce(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
514 ::testing::Return(absl::OkStatus())));
515 TF_ASSERT_OK(Create());
516 *request_.mutable_input()->mutable_example_list()->mutable_examples()->Add() =
517 example_with_output(2.0);
518 Status status = regressor_->Regress(request_, &result_);
519 ASSERT_FALSE(status.ok());
520 EXPECT_THAT(status.ToString(),
521 ::testing::HasSubstr(
"Expected output Tensor of DT_FLOAT"));
// The first expectation was WillOnce, so set a fresh one for RunRegress.
522 EXPECT_CALL(*mock, Run(_, _, _, _, _, _, _))
523 .WillOnce(::testing::DoAll(::testing::SetArgPointee<4>(outputs),
524 ::testing::Return(absl::OkStatus())));
525 RegressionResponse response;
526 status = RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
527 mock, request_, &response);
528 ASSERT_FALSE(status.ok());
529 EXPECT_THAT(status.ToString(),
530 ::testing::HasSubstr(
"Expected output Tensor of DT_FLOAT"));
// Overwrites the default serving signature with an empty SignatureDef before
// creating the regressor, then sends an example with a bytes feature "class".
// Both Regress and RunRegress must fail; the status is checked against
// absl::StatusCode::kInvalidArgument.
// NOTE(review): the `feature`/`example` declarations, the right-hand side of
// the example assignment and the second operand of each EXPECT_EQ are not
// visible in this excerpt.
533 TEST_P(RegressorTest, MissingRegressionSignature) {
534 auto* signature_defs = meta_graph_def_->mutable_signature_def();
535 SignatureDef sig_def;
536 (*signature_defs)[kDefaultServingSignatureDefKey] = sig_def;
537 TF_ASSERT_OK(Create());
539 feature.mutable_bytes_list()->add_value(
"uno");
541 (*example.mutable_features()->mutable_feature())[
"class"] = feature;
542 *request_.mutable_input()->mutable_example_list()->mutable_examples()->Add() =
545 Status status = regressor_->Regress(request_, &result_);
546 ASSERT_FALSE(status.ok());
547 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
550 RegressionResponse response;
551 status = RunRegress(GetRunOptions(), saved_model_bundle_->meta_graph_def, {},
552 fake_session_, request_, &response);
553 ASSERT_FALSE(status.ok());
554 EXPECT_EQ(
static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
// Exercises RunRegress against the default signature with three method-name
// settings: one that must succeed (TF_EXPECT_OK), the unsupported name
// "not/supported/method", and a cleared method name. For the last two the
// outcome is compared against !IsMethodNameCheckEnabled().
// NOTE(review): the first set_method_name argument and the middle of both
// EXPECT_EQ expressions are not visible in this excerpt; presumably the
// status's ok() is what gets compared -- confirm against the full file.
559 TEST_P(RegressorTest, MethodNameCheck) {
560 RegressionResponse response;
561 *request_.mutable_input()->mutable_example_list()->mutable_examples()->Add() =
562 example_with_output(2.0);
563 auto* signature_defs = meta_graph_def_->mutable_signature_def();
566 (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
568 TF_EXPECT_OK(RunRegress(GetRunOptions(), *meta_graph_def_, {}, fake_session_,
569 request_, &response));
572 (*signature_defs)[kDefaultServingSignatureDefKey].set_method_name(
573 "not/supported/method");
574 EXPECT_EQ(RunRegress(GetRunOptions(), *meta_graph_def_, {}, fake_session_,
577 !IsMethodNameCheckEnabled());
580 (*signature_defs)[kDefaultServingSignatureDefKey].clear_method_name();
581 EXPECT_EQ(RunRegress(GetRunOptions(), *meta_graph_def_, {}, fake_session_,
584 !IsMethodNameCheckEnabled());
// Instantiates every RegressorTest case for both values of the bool
// parameter, i.e. with the method-name check disabled and enabled.
587 INSTANTIATE_TEST_SUITE_P(Regressor, RegressorTest, ::testing::Bool());