TensorFlow Serving C++ API Documentation
remote_predict_op_kernel_test.cc
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/remote_predict_op_kernel.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/status/status.h"
23 #include "absl/time/time.h"
24 #include "tensorflow/cc/client/client_session.h"
25 #include "tensorflow/cc/ops/const_op.h"
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
30 #include "tensorflow_serving/experimental/tensorflow/ops/remote_predict/cc/ops/remote_predict_op.h"
31 
32 namespace tensorflow {
33 namespace serving {
34 namespace {
35 
// Stateless stand-in for a per-call RPC handle. MockPredictionService::CreateRpc
// hands these out and they are passed back into MockPredictionService::Predict;
// the mock never reads anything from them.
struct MockRpc {};
38 
39 // Mock class for RemotePredict Op kernel test.
40 class MockPredictionService {
41  public:
42  static absl::Status Create(const string& target_address,
43  std::unique_ptr<MockPredictionService>* service) {
44  service->reset(new MockPredictionService(target_address));
45  return ::absl::OkStatus();
46  }
47 
48  absl::StatusOr<MockRpc*> CreateRpc(absl::Duration max_rpc_deadline) {
49  return new MockRpc;
50  }
51 
52  // The model_name in request determines response and/or status.
53  void Predict(MockRpc* rpc, PredictRequest* request, PredictResponse* response,
54  std::function<void(absl::Status status)> callback);
55 
56  static constexpr char kGoodModel[] = "good_model";
57  static constexpr char kBadModel[] = "bad_model";
58 
59  private:
60  MockPredictionService(const string& target_address);
61 };
62 
63 constexpr char MockPredictionService::kGoodModel[];
64 constexpr char MockPredictionService::kBadModel[];
65 
66 typedef google::protobuf::Map<tensorflow::string, tensorflow::TensorProto> AliasTensorMap;
67 
68 MockPredictionService::MockPredictionService(const string& target_address) {}
69 
70 void MockPredictionService::Predict(
71  MockRpc* rpc, PredictRequest* request, PredictResponse* response,
72  std::function<void(absl::Status status)> callback) {
73  // Use model name to specify the behavior of each test.
74  std::string model_name = request->model_spec().name();
75  if (model_name == kGoodModel) {
76  *(response->mutable_model_spec()) = request->model_spec();
77  AliasTensorMap& inputs = *request->mutable_inputs();
78  AliasTensorMap& outputs = *response->mutable_outputs();
79  outputs["output0"] = inputs["input0"];
80  outputs["output1"] = inputs["input1"];
81  callback(::absl::OkStatus());
82  }
83 
84  if (model_name == kBadModel) {
85  callback(absl::Status(absl::StatusCode::kAborted, "Aborted"));
86  }
87 }
88 
89 REGISTER_KERNEL_BUILDER(Name("TfServingRemotePredict").Device(DEVICE_CPU),
90  RemotePredictOp<MockPredictionService>);
91 
92 using RemotePredict = ops::TfServingRemotePredict;
93 
94 // Use model_name to specify the behavior of different tests.
95 ::tensorflow::Status RunRemotePredict(
96  const string& model_name, std::vector<Tensor>* outputs,
97  const DataTypeSlice& output_types = {DT_INT32, DT_INT32},
98  const absl::optional<::absl::Duration> deadline = absl::nullopt,
99  bool fail_on_rpc_error = true,
100  const string& target_address = "target_address",
101  int64_t target_model_version = -1, const string& signature_name = "") {
102  const Scope scope = Scope::DisabledShapeInferenceScope();
103  // Model_name will decide the result of the RPC.
104  auto input_tensor_aliases = ops::Const(
105  scope.WithOpName("input_tensor_aliases"), {"input0", "input1"});
106  auto input_tensors0 = ops::Const(scope.WithOpName("input_tensors0"), {1, 2});
107  auto input_tensors1 = ops::Const(scope.WithOpName("input_tensors1"), {3, 4});
108  auto output_tensor_aliases = ops::Const(
109  scope.WithOpName("output_tensor_aliases"), {"output0", "output1"});
110  std::vector<Output> fetch_outputs;
111  RemotePredict::Attrs attrs = RemotePredict::Attrs()
112  .TargetAddress(target_address)
113  .ModelName(model_name)
114  .SignatureName(signature_name);
115 
116  if (target_model_version >= 0) {
117  attrs = attrs.ModelVersion(target_model_version);
118  }
119  if (deadline.has_value()) {
120  attrs = attrs.MaxRpcDeadlineMillis(absl::ToInt64Seconds(deadline.value()) *
121  1000);
122  }
123  attrs = attrs.FailOpOnRpcError(fail_on_rpc_error);
124 
125  auto remote_predict = RemotePredict(
126  scope, input_tensor_aliases, {input_tensors0, input_tensors1},
127  output_tensor_aliases, output_types, attrs);
128 
129  fetch_outputs = {remote_predict.status_code,
130  remote_predict.status_error_message};
131  fetch_outputs.insert(fetch_outputs.end(),
132  remote_predict.output_tensors.begin(),
133  remote_predict.output_tensors.end());
134  TF_RETURN_IF_ERROR(scope.status());
135 
136  ClientSession session(scope);
137  return session.Run(fetch_outputs, outputs);
138 }
139 
140 TEST(RemotePredictTest, TestSimple) {
141  std::vector<Tensor> outputs;
142  TF_ASSERT_OK(RunRemotePredict(
143  /*model_name=*/MockPredictionService::kGoodModel, &outputs));
144  ASSERT_EQ(4, outputs.size());
145  // Checks whether the status code is 0 and there is no error message.
146  EXPECT_EQ(0, outputs[0].scalar<int>()());
147  EXPECT_EQ("", outputs[1].scalar<tensorflow::tstring>()());
148  test::ExpectTensorEqual<int>(outputs[2], test::AsTensor<int>({1, 2}));
149  test::ExpectTensorEqual<int>(outputs[3], test::AsTensor<int>({3, 4}));
150 }
151 
152 TEST(RemotePredictTest, TestRpcError) {
153  std::vector<Tensor> outputs;
154  const auto status = RunRemotePredict(
155  /*model_name=*/MockPredictionService::kBadModel, &outputs);
156  ASSERT_FALSE(status.ok());
157  EXPECT_EQ(error::Code::ABORTED, status.code());
158  EXPECT_THAT(status.message(), ::testing::HasSubstr("Aborted"));
159 }
160 
161 TEST(RemotePredictTest, TestRpcErrorReturnStatus) {
162  std::vector<Tensor> outputs;
163  // Specifying output_types to float solves
164  // "MemorySanitizer: use-of-uninitialized-value"
165  const auto status = RunRemotePredict(
166  /*model_name=*/MockPredictionService::kBadModel, &outputs,
167  {DT_FLOAT, DT_FLOAT}, /*deadline=*/absl::nullopt,
168  /*fail_on_rpc_error=*/false);
169 
170  EXPECT_TRUE(status.ok());
171  EXPECT_EQ(static_cast<int>(error::Code::ABORTED), outputs[0].scalar<int>()());
172  EXPECT_EQ("Aborted", outputs[1].scalar<tensorflow::tstring>()());
173 }
174 
175 } // namespace
176 } // namespace serving
177 } // namespace tensorflow