16 #include "tensorflow_serving/servables/tensorflow/regression_service.h"
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/protobuf/config.pb.h"
26 #include "tensorflow_serving/config/model_server_config.pb.h"
27 #include "tensorflow_serving/core/availability_preserving_policy.h"
28 #include "tensorflow_serving/model_servers/model_platform_types.h"
29 #include "tensorflow_serving/model_servers/platform_config_util.h"
30 #include "tensorflow_serving/model_servers/server_core.h"
31 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
32 #include "tensorflow_serving/test_util/test_util.h"
34 namespace tensorflow {
// Name under which the single test model is registered in the ServerCore
// config (model_config->set_name(kTestModelName) in SetUpTestSuite), and the
// model name the test requests use.
38 constexpr
char kTestModelName[] =
"test_model";
// Test fixture for TensorflowRegressionServiceImpl. All tests in the suite
// share one ServerCore, held in the static member server_core_.
// GoogleTest runs SetUpTestSuite() once before the suite's first test and
// TearDownTestSuite() once after its last test.
42 class RegressionServiceTest :
public ::testing::Test {
// Builds the suite-wide server configuration: one model named kTestModelName,
// served from the half_plus_two SavedModel test data on the TensorFlow model
// platform.
44 static void SetUpTestSuite() {
45 ModelServerConfig config;
46 auto model_config = config.mutable_model_config_list()->add_config();
47 model_config->set_name(kTestModelName);
48 model_config->set_base_path(test_util::TensorflowTestSrcDirPath(
49 "cc/saved_model/testdata/half_plus_two"));
50 model_config->set_model_platform(kTensorFlowModelPlatform);
// ServerCore options: the model config above, a platform config map derived
// from a default-constructed SessionBundleConfig, and an
// AvailabilityPreservingPolicy as the aspired-version policy.
54 ServerCore::Options options;
55 options.model_server_config = config;
56 options.platform_config_map =
57 CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
58 options.aspired_version_policy =
59 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
// Sets the initial-load thread count equal to options.num_load_threads.
// NOTE(review): presumably this limits initial-load concurrency in tests;
// confirm the intent.
62 options.num_initial_load_threads = options.num_load_threads;
// NOTE(review): no statement that constructs server_core_ from `options` is
// visible in this view (e.g. a ServerCore::Create(std::move(options),
// &server_core_) call). Confirm that it exists. Without it, every test passes
// a null ServerCore to Regress().
//
// Resets the owning unique_ptr, which destroys the shared ServerCore after the
// suite has finished.
66 static void TearDownTestSuite() { server_core_ =
nullptr; }
// ServerCore shared by every test in the suite. Owned by this fixture and
// accessed by the tests via server_core_.get().
69 static std::unique_ptr<ServerCore> server_core_;
// Out-of-class definition of the fixture's static data member
// RegressionServiceTest::server_core_.
72 std::unique_ptr<ServerCore> RegressionServiceTest::server_core_;
// Verifies that Regress() rejects requests whose model_spec is missing, has no
// name, or names a model other than the configured kTestModelName.
// NOTE(review): part of each Regress() call is not visible in this view: the
// request/response arguments and whatever extracts the status code. The
// comments below describe only what is shown.
76 TEST_F(RegressionServiceTest, InvalidModelSpec) {
77 RegressionRequest request;
78 RegressionResponse response;
// Case 1: the request has no model_spec. Expected: kInvalidArgument.
81 EXPECT_EQ(TensorflowRegressionServiceImpl::Regress(
82 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
85 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument));
// Case 2: model_spec is present but has no name. Expected: kInvalidArgument.
88 auto* model_spec = request.mutable_model_spec();
89 EXPECT_EQ(TensorflowRegressionServiceImpl::Regress(
90 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
93 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument));
// Case 3: model_spec names "foo", which is not the only configured model
// ("test_model"). Expected: NOT_FOUND.
// NOTE(review): this check compares against the legacy
// tensorflow::error::NOT_FOUND, while the checks above use absl::StatusCode.
// Consider absl::StatusCode::kNotFound for consistency. Also, casting an
// absl::StatusCode to absl::StatusCode (above) is a no-op.
96 model_spec->set_name(
"foo");
97 EXPECT_EQ(TensorflowRegressionServiceImpl::Regress(
98 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
101 tensorflow::error::NOT_FOUND);
// Verifies that Regress() returns kInvalidArgument for a request that targets
// the configured model ("test_model") with the signature_name
// "invalid_signature_name".
// NOTE(review): the end of the request proto text and the trailing Regress()
// arguments are not visible in this view.
106 TEST_F(RegressionServiceTest, InvalidSignature) {
107 auto request = test_util::CreateProto<RegressionRequest>(
109 " name: \"test_model\""
110 " signature_name: \"invalid_signature_name\""
112 RegressionResponse response;
113 EXPECT_EQ(TensorflowRegressionServiceImpl::Regress(
114 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
117 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument));
// Verifies a successful regression against the "regress_x_to_y" signature of
// the configured model. The test expects:
// - Regress() to return OK, and
// - the response to contain a single regression value of 42, plus model-spec
//   fields echoing name "test_model", signature_name "regress_x_to_y" and
//   version 123.
// NOTE(review): the request's input examples are only partly visible here
// (one string feature value "pt_BR" is shown). The expected value 42 depends
// on those hidden inputs.
123 TEST_F(RegressionServiceTest, RegressionSuccess) {
124 auto request = test_util::CreateProto<RegressionRequest>(
126 " name: \"test_model\""
127 " signature_name: \"regress_x_to_y\""
145 " value: [ \"pt_BR\" ]"
161 RegressionResponse response;
// Default thread-pool options. The response is written to `response`.
162 TF_EXPECT_OK(TensorflowRegressionServiceImpl::Regress(
163 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(), request,
165 EXPECT_THAT(response,
166 test_util::EqualsProto(
"result { regressions { value: 42 } }"
168 " name: \"test_model\""
169 " signature_name: \"regress_x_to_y\""
170 " version { value: 123 }"
// Verifies that RegressWithModelSpec() resolves the model from its
// model_spec_override argument rather than from request.model_spec. Both calls
// use the same request, which targets "test_model":
// - Regress() is expected NOT to fail with NOT_FOUND.
// - RegressWithModelSpec() with an override naming "nonexistent_model" is
//   expected to fail with NOT_FOUND.
// NOTE(review): the trailing arguments and status-code accessors of these
// calls are only partly visible in this view.
176 TEST_F(RegressionServiceTest, ModelSpecOverride) {
177 auto request = test_util::CreateProto<RegressionRequest>(
179 " name: \"test_model\""
181 auto model_spec_override =
182 test_util::CreateProto<ModelSpec>(
"name: \"nonexistent_model\"");
184 RegressionResponse response;
// Without an override, the request's own model_spec is used.
185 EXPECT_NE(tensorflow::error::NOT_FOUND,
186 TensorflowRegressionServiceImpl::Regress(
187 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
// With the override, the unknown model name is used instead.
190 EXPECT_EQ(tensorflow::error::NOT_FOUND,
191 TensorflowRegressionServiceImpl::RegressWithModelSpec(
192 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
193 model_spec_override, request, &response)
// Verifies that caller-supplied thread pools are used. Regress() runs the same
// "regress_x_to_y" request as RegressionSuccess, but with CountingThreadPool
// instances installed as the inter-op and intra-op pools. The test expects
// the same response and at least one task scheduled on the inter-op pool.
// NOTE(review): the request proto text and the CountingThreadPool constructor
// arguments are only partly visible in this view.
197 TEST_F(RegressionServiceTest, ThreadPoolOptions) {
198 auto request = test_util::CreateProto<RegressionRequest>(
200 " name: \"test_model\""
201 " signature_name: \"regress_x_to_y\""
219 " value: [ \"pt_BR\" ]"
// Counting pools let the test observe how many tasks were scheduled on each.
236 test_util::CountingThreadPool inter_op_threadpool(Env::Default(),
"InterOp",
238 test_util::CountingThreadPool intra_op_threadpool(Env::Default(),
"IntraOp",
// Non-owning pointers into the local pools above. The pools outlive the
// Regress() call because they are locals of this test.
240 thread::ThreadPoolOptions thread_pool_options;
241 thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
242 thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
243 RegressionResponse response;
244 TF_EXPECT_OK(TensorflowRegressionServiceImpl::Regress(
245 RunOptions(), server_core_.get(), thread_pool_options, request,
247 EXPECT_THAT(response,
248 test_util::EqualsProto(
"result { regressions { value: 42 } }"
250 " name: \"test_model\""
251 " signature_name: \"regress_x_to_y\""
252 " version { value: 123 }"
// Only the inter-op pool's usage is asserted in the visible code.
255 ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
// NOTE(review): stray body-less declaration. It looks like ServerCore's
// factory signature: it takes an Options value and a
// std::unique_ptr<ServerCore>* out-parameter. It does not appear to belong to
// this test file's structure; confirm where it came from before relying on it.
static Status Create(Options options, std::unique_ptr< ServerCore > *core)