16 #include "tensorflow_serving/servables/tensorflow/classification_service.h"
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/threadpool_options.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow_serving/config/model_server_config.pb.h"
28 #include "tensorflow_serving/core/availability_preserving_policy.h"
29 #include "tensorflow_serving/model_servers/model_platform_types.h"
30 #include "tensorflow_serving/model_servers/platform_config_util.h"
31 #include "tensorflow_serving/model_servers/server_core.h"
32 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
33 #include "tensorflow_serving/test_util/test_util.h"
35 namespace tensorflow {
// Name under which the half_plus_two test SavedModel is served in this suite.
constexpr char kTestModelName[] = "test_model";
43 class ClassificationServiceTest :
public ::testing::Test {
45 static void SetUpTestSuite() {
46 ModelServerConfig config;
47 auto model_config = config.mutable_model_config_list()->add_config();
48 model_config->set_name(kTestModelName);
49 model_config->set_base_path(test_util::TensorflowTestSrcDirPath(
50 "cc/saved_model/testdata/half_plus_two"));
51 model_config->set_model_platform(kTensorFlowModelPlatform);
55 ServerCore::Options options;
56 options.model_server_config = config;
57 options.platform_config_map =
58 CreateTensorFlowPlatformConfigMap(SessionBundleConfig());
59 options.aspired_version_policy =
60 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
63 options.num_initial_load_threads = options.num_load_threads;
67 static void TearDownTestSuite() { server_core_ =
nullptr; }
70 static std::unique_ptr<ServerCore> server_core_;
73 std::unique_ptr<ServerCore> ClassificationServiceTest::server_core_;
77 TEST_F(ClassificationServiceTest, InvalidModelSpec) {
78 ClassificationRequest request;
79 ClassificationResponse response;
82 EXPECT_EQ(TensorflowClassificationServiceImpl::Classify(
83 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
86 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument));
89 auto* model_spec = request.mutable_model_spec();
90 EXPECT_EQ(TensorflowClassificationServiceImpl::Classify(
91 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
94 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument));
97 model_spec->set_name(
"foo");
98 EXPECT_EQ(TensorflowClassificationServiceImpl::Classify(
99 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
102 tensorflow::error::NOT_FOUND);
107 TEST_F(ClassificationServiceTest, InvalidSignature) {
108 auto request = test_util::CreateProto<ClassificationRequest>(
110 " name: \"test_model\""
111 " signature_name: \"invalid_signature_name\""
113 ClassificationResponse response;
114 EXPECT_EQ(TensorflowClassificationServiceImpl::Classify(
115 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
118 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument));
124 TEST_F(ClassificationServiceTest, ClassificationSuccess) {
125 auto request = test_util::CreateProto<ClassificationRequest>(
127 " name: \"test_model\""
128 " signature_name: \"classify_x_to_y\""
146 " value: [ \"pt_BR\" ]"
162 ClassificationResponse response;
163 TF_EXPECT_OK(TensorflowClassificationServiceImpl::Classify(
164 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(), request,
166 EXPECT_THAT(response,
167 test_util::EqualsProto(
168 "result { classifications { classes { score: 42 } } }"
170 " name: \"test_model\""
171 " signature_name: \"classify_x_to_y\""
172 " version { value: 123 }"
178 TEST_F(ClassificationServiceTest, ModelSpecOverride) {
179 auto request = test_util::CreateProto<ClassificationRequest>(
181 " name: \"test_model\""
183 auto model_spec_override =
184 test_util::CreateProto<ModelSpec>(
"name: \"nonexistent_model\"");
186 ClassificationResponse response;
187 EXPECT_NE(tensorflow::error::NOT_FOUND,
188 TensorflowClassificationServiceImpl::Classify(
189 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
192 EXPECT_EQ(tensorflow::error::NOT_FOUND,
193 TensorflowClassificationServiceImpl::ClassifyWithModelSpec(
194 RunOptions(), server_core_.get(), thread::ThreadPoolOptions(),
195 model_spec_override, request, &response)
199 TEST_F(ClassificationServiceTest, ThreadPoolOptions) {
200 auto request = test_util::CreateProto<ClassificationRequest>(
202 " name: \"test_model\""
203 " signature_name: \"classify_x_to_y\""
221 " value: [ \"pt_BR\" ]"
238 test_util::CountingThreadPool inter_op_threadpool(Env::Default(),
"InterOp",
240 test_util::CountingThreadPool intra_op_threadpool(Env::Default(),
"IntraOp",
242 thread::ThreadPoolOptions thread_pool_options;
243 thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
244 thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
245 ClassificationResponse response;
246 TF_EXPECT_OK(TensorflowClassificationServiceImpl::Classify(
247 RunOptions(), server_core_.get(), thread_pool_options, request,
249 EXPECT_THAT(response,
250 test_util::EqualsProto(
251 "result { classifications { classes { score: 42 } } }"
253 " name: \"test_model\""
254 " signature_name: \"classify_x_to_y\""
255 " version { value: 123 }"
259 ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
static Status Create(Options options, std::unique_ptr< ServerCore > *core)