16 #include "tensorflow_serving/servables/tensorflow/tfrt_regression_service.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/protobuf/config.pb.h"
24 #include "tensorflow/core/tfrt/saved_model/saved_model.h"
25 #include "tensorflow_serving/config/model_server_config.pb.h"
26 #include "tensorflow_serving/core/availability_preserving_policy.h"
27 #include "tensorflow_serving/model_servers/model_platform_types.h"
28 #include "tensorflow_serving/model_servers/platform_config_util.h"
29 #include "tensorflow_serving/model_servers/server_core.h"
30 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
31 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
32 #include "tensorflow_serving/test_util/test_util.h"
34 namespace tensorflow {
// Name assigned to the single model entry that the test fixture's
// SetUpTestSuite() adds to the ModelServerConfig.
38 constexpr
char kTestModelName[] =
"test_model";
// Test fixture for TFRTRegressionServiceImpl.
//
// SetUpTestSuite() installs a global TFRT runtime and assembles the
// ServerCore::Options for one model (kTestModelName). server_core_ is a
// static member shared by every test and reset in TearDownTestSuite().
// NOTE(review): the statement that actually populates server_core_ from
// `options` is not visible in this excerpt -- confirm against the full file.
42 class TFRTRegressionServiceTest :
public ::testing::Test {
// Suite-level setup; runs as a static function, so it only touches static
// state and locals.
44 static void SetUpTestSuite() {
// NOTE(review): the argument 4 is presumably a thread count -- TODO confirm
// against tfrt_stub::Runtime::Create.
45 tfrt_stub::SetGlobalRuntime(
46 tfrt_stub::Runtime::Create(4));
// Single model entry: named kTestModelName, based at the
// saved_model_half_plus_two_cpu testdata directory, on kTensorFlowModelPlatform.
48 ModelServerConfig config;
49 auto model_config = config.mutable_model_config_list()->add_config();
50 model_config->set_name(kTestModelName);
51 model_config->set_base_path(
52 test_util::TestSrcDirPath(
"servables/tensorflow/"
53 "testdata/saved_model_half_plus_two_cpu"));
54 model_config->set_model_platform(kTensorFlowModelPlatform);
58 ServerCore::Options options;
59 options.model_server_config = config;
// Register a default-constructed TfrtSavedModelSourceAdapterConfig (packed into
// an Any) as the source adapter config for kTensorFlowModelPlatform.
60 PlatformConfigMap platform_config_map;
61 ::google::protobuf::Any source_adapter_config;
62 TfrtSavedModelSourceAdapterConfig saved_model_bundle_source_adapter_config;
63 source_adapter_config.PackFrom(saved_model_bundle_source_adapter_config);
64 (*(*platform_config_map
65 .mutable_platform_configs())[kTensorFlowModelPlatform]
66 .mutable_source_adapter_config()) = source_adapter_config;
67 options.platform_config_map = platform_config_map;
// The options take ownership of a freshly allocated AvailabilityPreservingPolicy.
68 options.aspired_version_policy =
69 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
// Initial-load thread count mirrors whatever num_load_threads currently holds.
72 options.num_initial_load_threads = options.num_load_threads;
// Suite-level teardown: assigning nullptr releases the ServerCore owned by the
// static unique_ptr.
76 static void TearDownTestSuite() { server_core_ =
nullptr; }
// Shared across all tests in the suite (static); defined out of class below.
79 static std::unique_ptr<ServerCore> server_core_;
// Default-constructed run options; the tests pass this to every Regress call.
80 Servable::RunOptions run_options_;
// Out-of-class definition of the fixture's static server_core_ member; it
// starts out empty (default-constructed unique_ptr).
83 std::unique_ptr<ServerCore> TFRTRegressionServiceTest::server_core_;
// Checks the status codes Regress reports for missing or unknown model specs.
// NOTE(review): the middle argument lines of each Regress call are missing
// from this excerpt; the notes below rely only on the visible status codes.
// NOTE(review): the expectations mix absl::StatusCode::kInvalidArgument with
// tensorflow::error::NOT_FOUND -- consider using one enum family throughout.
87 TEST_F(TFRTRegressionServiceTest, InvalidModelSpec) {
88 RegressionRequest request;
89 RegressionResponse response;
// Request as constructed above (no model spec set, as far as visible here):
// kInvalidArgument is expected.
92 EXPECT_EQ(TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
95 absl::StatusCode::kInvalidArgument);
// A model spec now exists but has no name: kInvalidArgument is still expected.
98 auto* model_spec = request.mutable_model_spec();
99 EXPECT_EQ(TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
102 absl::StatusCode::kInvalidArgument);
// "foo" differs from the only model name configured in SetUpTestSuite
// (kTestModelName): NOT_FOUND is expected.
105 model_spec->set_name(
"foo");
106 EXPECT_EQ(TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
109 tensorflow::error::NOT_FOUND);
// Requests the configured model ("test_model") with signature_name
// "invalid_signature_name" and expects Regress to report FAILED_PRECONDITION.
// NOTE(review): parts of the text-proto literal and of the Regress call are
// missing from this excerpt.
114 TEST_F(TFRTRegressionServiceTest, InvalidSignature) {
115 auto request = test_util::CreateProto<RegressionRequest>(
117 " name: \"test_model\""
118 " signature_name: \"invalid_signature_name\""
120 RegressionResponse response;
121 EXPECT_EQ(TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
124 tensorflow::error::FAILED_PRECONDITION);
// Happy path: a request for "test_model" with signature "regress_x_to_y" must
// succeed, and the response must equal the expected proto (a single regression
// value of 42 plus a model spec echoing name, signature and version 123).
// NOTE(review): most of the request's input examples and part of the expected
// proto literal are missing from this excerpt.
130 TEST_F(TFRTRegressionServiceTest, RegressionSuccess) {
131 auto request = test_util::CreateProto<RegressionRequest>(
133 " name: \"test_model\""
134 " signature_name: \"regress_x_to_y\""
152 " value: [ \"pt_BR\" ]"
168 RegressionResponse response;
// The call itself must return an OK status.
169 TF_EXPECT_OK(TFRTRegressionServiceImpl::Regress(
170 run_options_, server_core_.get(), request, &response));
// The populated response is compared against the expected text proto.
171 EXPECT_THAT(response,
172 test_util::EqualsProto(
"result { regressions { value: 42 } }"
174 " name: \"test_model\""
175 " signature_name: \"regress_x_to_y\""
176 " version { value: 123 }"
// Compares Regress against RegressWithModelSpec when an override model spec
// naming "nonexistent_model" is supplied alongside a request for "test_model".
// NOTE(review): the tails of both expectation statements are missing from this
// excerpt; the notes below rely only on the visible tokens.
182 TEST_F(TFRTRegressionServiceTest, ModelSpecOverride) {
183 auto request = test_util::CreateProto<RegressionRequest>(
185 " name: \"test_model\""
187 auto model_spec_override =
188 test_util::CreateProto<ModelSpec>(
"name: \"nonexistent_model\"");
190 RegressionResponse response;
// Plain Regress: whatever it reports, it must not be NOT_FOUND.
191 EXPECT_NE(tensorflow::error::NOT_FOUND,
192 TFRTRegressionServiceImpl::Regress(run_options_, server_core_.get(),
// RegressWithModelSpec receives model_spec_override and must report NOT_FOUND.
195 EXPECT_EQ(tensorflow::error::NOT_FOUND,
196 TFRTRegressionServiceImpl::RegressWithModelSpec(
197 run_options_, server_core_.get(), model_spec_override, request,
// NOTE(review): stray signature with no body or terminator. It declares a
// static Create taking Options by value and a pointer to
// std::unique_ptr<ServerCore>, returning Status. Presumably this is
// ServerCore::Create leaking in from a source-listing tooltip rather than
// part of this test file -- confirm and remove if so.
static Status Create(Options options, std::unique_ptr< ServerCore > *core)