15 #include "tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/remote_predict_op_kernel.h"
22 #include "absl/status/status.h"
23 #include "absl/time/time.h"
24 #include "tensorflow/cc/client/client_session.h"
25 #include "tensorflow/cc/ops/const_op.h"
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
30 #include "tensorflow_serving/experimental/tensorflow/ops/remote_predict/cc/ops/remote_predict_op.h"
32 namespace tensorflow {
// Test double for the prediction service client. It is used as the template
// argument of RemotePredictOp in this file's kernel registration, so the op can
// be exercised without a real server. Predict() succeeds for kGoodModel and
// fails with ABORTED for kBadModel (see its out-of-line definition).
40 class MockPredictionService {
// Factory matching the interface RemotePredictOp expects: the visible body
// unconditionally stores a new mock in *service and returns OK.
42 static absl::Status Create(
const string& target_address,
43 std::unique_ptr<MockPredictionService>* service) {
44 service->reset(
new MockPredictionService(target_address));
45 return ::absl::OkStatus();
// NOTE(review): the body of CreateRpc is not visible here; presumably it hands
// back a MockRpc for the given deadline — confirm against the full file.
48 absl::StatusOr<MockRpc*> CreateRpc(absl::Duration max_rpc_deadline) {
// Issues the (mock) prediction and reports the outcome through `callback`;
// on success `response` is filled in.
53 void Predict(MockRpc* rpc, PredictRequest* request, PredictResponse* response,
54 std::function<
void(absl::Status status)> callback);
// Model names recognised by Predict(): kGoodModel echoes the inputs back,
// kBadModel makes the RPC fail with an ABORTED status.
56 static constexpr
char kGoodModel[] =
"good_model";
57 static constexpr
char kBadModel[] =
"bad_model";
// Accepts a target address; the mock's constructor body is empty, so the
// address is not used.
60 MockPredictionService(
const string& target_address);
// Out-of-line definitions of the static constexpr arrays. Before C++17 they are
// required whenever the arrays are ODR-used (e.g. when they decay to pointers in
// comparisons or are bound to const string& parameters). Since C++17, static
// constexpr data members are implicitly inline, so these lines are redundant
// but harmless.
63 constexpr
char MockPredictionService::kGoodModel[];
64 constexpr
char MockPredictionService::kBadModel[];
66 typedef google::protobuf::Map<tensorflow::string, tensorflow::TensorProto> AliasTensorMap;
68 MockPredictionService::MockPredictionService(
const string& target_address) {}
// Mock prediction RPC, keyed on request->model_spec().name():
// - kGoodModel: copies the request's model_spec into the response, echoes
//   input "input0" as output "output0" and "input1" as "output1", then invokes
//   `callback` with OK.
// - kBadModel: invokes `callback` with an ABORTED status whose message is
//   "Aborted".
// `rpc` is not read by the visible code.
// NOTE(review): the tail of this function is not visible here; confirm that any
// other model name also invokes `callback`, otherwise the caller waiting for the
// completion would never be notified.
70 void MockPredictionService::Predict(
71 MockRpc* rpc, PredictRequest* request, PredictResponse* response,
72 std::function<
void(absl::Status status)> callback) {
74 std::string model_name = request->model_spec().name();
75 if (model_name == kGoodModel) {
76 *(response->mutable_model_spec()) = request->model_spec();
// protobuf Map::operator[] inserts a default TensorProto when the key is
// absent, so a request lacking an input alias is mutated and the response gets
// an empty tensor for it.
77 AliasTensorMap& inputs = *request->mutable_inputs();
78 AliasTensorMap& outputs = *response->mutable_outputs();
79 outputs[
"output0"] = inputs[
"input0"];
80 outputs[
"output1"] = inputs[
"input1"];
81 callback(::absl::OkStatus());
84 if (model_name == kBadModel) {
85 callback(absl::Status(absl::StatusCode::kAborted,
"Aborted"));
// Register RemotePredictOp, instantiated with the mock service, as the CPU
// kernel for the TfServingRemotePredict op. The tests in this file therefore
// run the real kernel logic against MockPredictionService.
89 REGISTER_KERNEL_BUILDER(Name(
"TfServingRemotePredict").Device(DEVICE_CPU),
90 RemotePredictOp<MockPredictionService>);
// Short alias for the C++ op wrapper from remote_predict_op.h.
92 using RemotePredict = ops::TfServingRemotePredict;
// Builds a graph that sends two int32 tensors ({1, 2} under alias "input0" and
// {3, 4} under alias "input1") through a TfServingRemotePredict op requesting
// outputs "output0"/"output1", then runs it in a ClientSession.
//
// model_name           ModelName attr; selects the mock's behavior.
// outputs              Receives the fetched tensors in this order:
//                      status_code, status_error_message, then one tensor per
//                      entry of output_types.
// output_types         Dtypes of the requested output tensors.
// deadline             When set, forwarded as the MaxRpcDeadlineMillis attr.
// fail_on_rpc_error    FailOpOnRpcError attr. Presumably false makes RPC
//                      failures surface via the status outputs instead of
//                      failing Run() — confirm against the kernel.
// target_address       TargetAddress attr.
// target_model_version ModelVersion attr; only set when >= 0.
// signature_name       SignatureName attr.
//
// Returns the graph-construction error (scope.status()) if any, otherwise the
// status of session.Run().
95 ::tensorflow::Status RunRemotePredict(
96 const string& model_name, std::vector<Tensor>* outputs,
97 const DataTypeSlice& output_types = {DT_INT32, DT_INT32},
98 const absl::optional<::absl::Duration> deadline = absl::nullopt,
99 bool fail_on_rpc_error =
true,
100 const string& target_address =
"target_address",
101 int64_t target_model_version = -1,
const string& signature_name =
"") {
102 const Scope scope = Scope::DisabledShapeInferenceScope();
104 auto input_tensor_aliases = ops::Const(
105 scope.WithOpName(
"input_tensor_aliases"), {
"input0",
"input1"});
106 auto input_tensors0 = ops::Const(scope.WithOpName(
"input_tensors0"), {1, 2});
107 auto input_tensors1 = ops::Const(scope.WithOpName(
"input_tensors1"), {3, 4});
108 auto output_tensor_aliases = ops::Const(
109 scope.WithOpName(
"output_tensor_aliases"), {
"output0",
"output1"});
110 std::vector<Output> fetch_outputs;
111 RemotePredict::Attrs attrs = RemotePredict::Attrs()
112 .TargetAddress(target_address)
113 .ModelName(model_name)
114 .SignatureName(signature_name);
116 if (target_model_version >= 0) {
117 attrs = attrs.ModelVersion(target_model_version);
// NOTE(review): absl::ToInt64Seconds truncates to whole seconds. If the
// continuation of the next statement (not visible here) multiplies by 1000,
// a deadline such as absl::Milliseconds(500) becomes 0 ms.
// absl::ToInt64Milliseconds(deadline.value()) would keep sub-second precision.
119 if (deadline.has_value()) {
120 attrs = attrs.MaxRpcDeadlineMillis(absl::ToInt64Seconds(deadline.value()) *
123 attrs = attrs.FailOpOnRpcError(fail_on_rpc_error);
125 auto remote_predict = RemotePredict(
126 scope, input_tensor_aliases, {input_tensors0, input_tensors1},
127 output_tensor_aliases, output_types, attrs);
// Fetch the two status tensors first, followed by the model's output tensors;
// the tests index *outputs relying on this order.
129 fetch_outputs = {remote_predict.status_code,
130 remote_predict.status_error_message};
131 fetch_outputs.insert(fetch_outputs.end(),
132 remote_predict.output_tensors.begin(),
133 remote_predict.output_tensors.end());
134 TF_RETURN_IF_ERROR(scope.status());
136 ClientSession session(scope);
137 return session.Run(fetch_outputs, outputs);
// Happy path: the mock echoes the inputs, so the op reports status code 0 with
// an empty error message and returns the two input tensors unchanged.
140 TEST(RemotePredictTest, TestSimple) {
141 std::vector<Tensor> outputs;
142 TF_ASSERT_OK(RunRemotePredict(
143 MockPredictionService::kGoodModel, &outputs));
// Two status tensors plus the two echoed outputs.
144 ASSERT_EQ(4, outputs.size());
146 EXPECT_EQ(0, outputs[0].scalar<int>()());
147 EXPECT_EQ(
"", outputs[1].scalar<tensorflow::tstring>()());
148 test::ExpectTensorEqual<int>(outputs[2], test::AsTensor<int>({1, 2}));
149 test::ExpectTensorEqual<int>(outputs[3], test::AsTensor<int>({3, 4}));
// With the default fail_on_rpc_error=true, the mock's ABORTED RPC failure must
// propagate out of session.Run() as the returned status.
152 TEST(RemotePredictTest, TestRpcError) {
153 std::vector<Tensor> outputs;
154 const auto status = RunRemotePredict(
155 MockPredictionService::kBadModel, &outputs);
156 ASSERT_FALSE(status.ok());
157 EXPECT_EQ(error::Code::ABORTED, status.code());
158 EXPECT_THAT(status.message(), ::testing::HasSubstr(
"Aborted"));
// When the op is configured not to fail on RPC errors (presumably
// fail_on_rpc_error=false is passed on the argument lines not shown here —
// confirm), session.Run() succeeds and the ABORTED code and "Aborted" message
// are reported through the status_code / status_error_message outputs.
// NOTE(review): EXPECT_TRUE(status.ok()) lets the test continue on failure and
// then index outputs[0]/outputs[1], which may be out of range. ASSERT_TRUE plus
// an ASSERT on outputs.size() would stop the test cleanly instead.
161 TEST(RemotePredictTest, TestRpcErrorReturnStatus) {
162 std::vector<Tensor> outputs;
165 const auto status = RunRemotePredict(
166 MockPredictionService::kBadModel, &outputs,
167 {DT_FLOAT, DT_FLOAT}, absl::nullopt,
170 EXPECT_TRUE(status.ok());
171 EXPECT_EQ(
static_cast<int>(error::Code::ABORTED), outputs[0].scalar<int>()());
172 EXPECT_EQ(
"Aborted", outputs[1].scalar<tensorflow::tstring>()());