#include "tensorflow_serving/servables/tensorflow/tfrt_classification_service.h"

#include <memory>
#include <utility>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow_serving/config/model_server_config.pb.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/model_servers/platform_config_util.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
#include "tensorflow_serving/test_util/test_util.h"
namespace tensorflow {
namespace serving {
namespace {

constexpr char kTestModelName[] = "test_model";

class TFRTClassificationServiceTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() {
    tfrt_stub::SetGlobalRuntime(tfrt_stub::Runtime::Create(4));

    ModelServerConfig config;
    auto model_config = config.mutable_model_config_list()->add_config();
    model_config->set_name(kTestModelName);
    model_config->set_base_path(
        test_util::TestSrcDirPath("servables/tensorflow/"
                                  "testdata/saved_model_half_plus_two_cpu"));
    model_config->set_model_platform(kTensorFlowModelPlatform);
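
    // Route the TensorFlow platform to the TFRT saved-model source adapter and
    // use the availability-preserving version policy when loading the model.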
    ServerCore::Options options;
    options.model_server_config = config;
    PlatformConfigMap platform_config_map;
    ::google::protobuf::Any source_adapter_config;
    TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
    source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
    (*(*platform_config_map
            .mutable_platform_configs())[kTensorFlowModelPlatform]
          .mutable_source_adapter_config()) = source_adapter_config;
    options.platform_config_map = platform_config_map;
    options.aspired_version_policy =
        std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
    // Match the initial load threads to the number of load threads to avoid
    // timing out while loading the test model.
    options.num_initial_load_threads = options.num_load_threads;
    TF_ASSERT_OK(ServerCore::Create(std::move(options), &server_core_));
  }
  static void TearDownTestSuite() { server_core_ = nullptr; }

 protected:
  static std::unique_ptr<ServerCore> server_core_;
  Servable::RunOptions run_options_;
};
std::unique_ptr<ServerCore> TFRTClassificationServiceTest::server_core_;
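
// Verifies that Classify() returns an error for an invalid or missing
// model_spec in the ClassificationRequest.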
TEST_F(TFRTClassificationServiceTest, InvalidModelSpec) {
  ClassificationRequest request;
  ClassificationResponse response;

  // No model_spec specified.
  EXPECT_EQ(TFRTClassificationServiceImpl::Classify(
                run_options_, server_core_.get(), request, &response)
                .code(),
            absl::StatusCode::kInvalidArgument);

  // No model name specified in the model_spec.
  auto* model_spec = request.mutable_model_spec();
  EXPECT_EQ(TFRTClassificationServiceImpl::Classify(
                run_options_, server_core_.get(), request, &response)
                .code(),
            absl::StatusCode::kInvalidArgument);

  // No servable loaded under the name "foo".
  model_spec->set_name("foo");
  EXPECT_EQ(TFRTClassificationServiceImpl::Classify(
                run_options_, server_core_.get(), request, &response)
                .code(),
            tensorflow::error::NOT_FOUND);
}
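
// Verifies that Classify() fails with FAILED_PRECONDITION when the request
// names a signature that does not exist in the SavedModel.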
TEST_F(TFRTClassificationServiceTest, InvalidSignature) {
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      "  name: \"test_model\""
      "  signature_name: \"invalid_signature_name\""
      "}");
  ClassificationResponse response;
  EXPECT_EQ(TFRTClassificationServiceImpl::Classify(
                run_options_, server_core_.get(), request, &response)
                .code(),
            tensorflow::error::FAILED_PRECONDITION);
}
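
// Runs a successful classification against the half-plus-two test model's
// classify_x_to_y signature and checks the returned score and model_spec.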
TEST_F(TFRTClassificationServiceTest, ClassificationSuccess) {
  // half_plus_two computes y = 0.5 * x + 2, so x = 80 yields the expected
  // score of 42; the "locale" feature is unused by classify_x_to_y.
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      "  name: \"test_model\""
      "  signature_name: \"classify_x_to_y\""
      "}"
      "input { example_list { examples { features {"
      "  feature { key: \"x\" value { float_list { value: [ 80.0 ] } } }"
      "  feature { key: \"locale\" value { bytes_list { value: [ \"pt_BR\" ] } } }"
      "} } } }");
  ClassificationResponse response;
  TF_EXPECT_OK(TFRTClassificationServiceImpl::Classify(
      run_options_, server_core_.get(), request, &response));
  EXPECT_THAT(response,
              test_util::EqualsProto(
                  "result { classifications { classes { score: 42 } } }"
                  "model_spec {"
                  "  name: \"test_model\""
                  "  signature_name: \"classify_x_to_y\""
                  "  version { value: 123 }"
                  "}"));
}
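
// Verifies that ClassifyWithModelSpec() uses the explicit model spec override
// rather than the model_spec embedded in the request.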
TEST_F(TFRTClassificationServiceTest, ModelSpecOverride) {
  auto request = test_util::CreateProto<ClassificationRequest>(
      "model_spec {"
      "  name: \"test_model\""
      "}");
  auto model_spec_override =
      test_util::CreateProto<ModelSpec>("name: \"nonexistent_model\"");

  ClassificationResponse response;
  // Without the override, the request's valid model name is used.
  EXPECT_NE(tensorflow::error::NOT_FOUND,
            TFRTClassificationServiceImpl::Classify(
                run_options_, server_core_.get(), request, &response)
                .code());
  // With the override, the nonexistent model name takes precedence.
  EXPECT_EQ(tensorflow::error::NOT_FOUND,
            TFRTClassificationServiceImpl::ClassifyWithModelSpec(
                run_options_, server_core_.get(), model_spec_override, request,
                &response)
                .code());
}

}  // namespace
}  // namespace serving
}  // namespace tensorflow