// TensorFlow Serving C++ API Documentation
// grpc_client.cc
1 /* Copyright 2017 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <atomic>
17 #include <chrono> // NOLINT(build/c++11)
18 #include <fstream>
19 #include <iostream>
20 #include <memory>
21 #include <sstream>
22 #include <string>
23 
24 #include "grpcpp/create_channel.h"
25 #include "grpcpp/security/credentials.h"
26 #include "absl/flags/flag.h"
27 #include "absl/flags/parse.h"
28 #include "absl/strings/match.h"
29 #include "xla/tsl/lib/histogram/histogram.h"
30 #include "tensorflow/core/platform/env.h"
31 #include "tensorflow/core/platform/env_time.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
34 
35 using grpc::ClientAsyncResponseReader;
36 using grpc::ClientContext;
37 using grpc::CompletionQueue;
38 using grpc::Status;
39 
40 using tensorflow::Env;
41 using tensorflow::EnvTime;
42 using tensorflow::Thread;
43 using tensorflow::protobuf::TextFormat;
44 using tensorflow::serving::PredictRequest;
45 using tensorflow::serving::PredictResponse;
46 using tensorflow::serving::PredictionService;
47 using tsl::histogram::ThreadSafeHistogram;
48 
// Command-line flags controlling the target server, the request payload, and
// the load profile (total request count, send rate, per-RPC deadline).
ABSL_FLAG(std::string, server_port, "", "Target server (host:port)");
ABSL_FLAG(std::string, request, "", "File containing request proto message");
ABSL_FLAG(std::string, model_name, "", "Model name to override in the request");
ABSL_FLAG(int, model_version, -1, "Model version to override in the request");
ABSL_FLAG(int, num_requests, 1, "Total number of requests to send.");
ABSL_FLAG(int, qps, 1, "Rate for sending requests.");
ABSL_FLAG(int, rpc_deadline_ms, 1000, "RPC request deadline in milliseconds.");
ABSL_FLAG(bool, print_rpc_errors, false, "Print RPC errors.");
57 
58 bool ReadProtoFromFile(const std::string& filename, PredictRequest* req) {
59  auto in = std::ifstream(filename);
60  if (!in) return false;
61  std::ostringstream ss;
62  ss << in.rdbuf();
63  return absl::EndsWith(filename, ".pbtxt")
64  ? TextFormat::ParseFromString(ss.str(), req)
65  : req->ParseFromString(ss.str());
66 }
67 
68 class ServingClient {
69  public:
70  ServingClient(const std::string& server_port)
71  : stub_(PredictionService::NewStub(grpc::CreateChannel(
72  server_port, grpc::InsecureChannelCredentials()))),
73  done_count_(0),
74  success_count_(0),
75  error_count_(0),
76  latency_histogram_(new ThreadSafeHistogram()),
77  error_histogram_(new ThreadSafeHistogram(
78  // Range from grpc::StatusCode enum.
79  {0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})) {
80  thread_.reset(Env::Default()->StartThread({}, "reaprpcs",
81  [this]() { ReapCompletions(); }));
82  }
83 
84  void IssuePredict(const PredictRequest& req) {
85  auto state = new RpcState();
86  state->context.set_deadline(
87  std::chrono::system_clock::now() +
88  std::chrono::milliseconds(absl::GetFlag(FLAGS_rpc_deadline_ms)));
89  state->start_time = EnvTime::NowMicros();
90  std::unique_ptr<ClientAsyncResponseReader<PredictResponse>> rpc(
91  stub_->AsyncPredict(&state->context, req, &cq_));
92  rpc->Finish(&state->resp, &state->status, (void*)state);
93  }
94 
95  void ReapCompletions() {
96  void* sp;
97  bool ok = false;
98  while (cq_.Next(&sp, &ok)) {
99  done_count_++;
100  std::unique_ptr<RpcState> state((RpcState*)sp);
101  if (state->status.ok()) {
102  success_count_++;
103  latency_histogram_->Add(EnvTime::NowMicros() - state->start_time);
104  } else {
105  error_count_++;
106  error_histogram_->Add(state->status.error_code());
107  if (absl::GetFlag(FLAGS_print_rpc_errors)) {
108  std::cerr << "ERROR: RPC failed code: " << state->status.error_code()
109  << " msg: " << state->status.error_message() << std::endl;
110  }
111  }
112  }
113  }
114 
115  void WaitForCompletion(int total_rpcs) {
116  while (done_count_ < total_rpcs) {
117  Env::Default()->SleepForMicroseconds(1000);
118  }
119  cq_.Shutdown();
120  thread_.reset();
121  }
122 
123  void DumpStats() {
124  if (success_count_) {
125  std::cout << "Request stats (successful)" << std::endl;
126  std::cout << latency_histogram_->ToString() << std::endl;
127  }
128  if (error_count_) {
129  std::cout << "Request stats (errors)" << std::endl;
130  std::cout << error_histogram_->ToString() << std::endl;
131  }
132  }
133 
134  private:
135  struct RpcState {
136  uint64_t start_time;
137  ClientContext context;
138  PredictResponse resp;
139  Status status;
140  };
141  std::unique_ptr<PredictionService::Stub> stub_;
142  CompletionQueue cq_;
143  std::unique_ptr<Thread> thread_;
144  std::atomic<int> done_count_;
145  std::atomic<int> success_count_;
146  std::atomic<int> error_count_;
147  std::unique_ptr<ThreadSafeHistogram> latency_histogram_;
148  std::unique_ptr<ThreadSafeHistogram> error_histogram_;
149 };
150 
151 int main(int argc, char** argv) {
152  absl::ParseCommandLine(argc, argv);
153 
154  if (absl::GetFlag(FLAGS_server_port).empty() ||
155  absl::GetFlag(FLAGS_request).empty()) {
156  std::cerr << "ERROR: --server_port and --request flags are required."
157  << std::endl;
158  return 1;
159  }
160 
161  PredictRequest req;
162  if (!ReadProtoFromFile(absl::GetFlag(FLAGS_request), &req)) {
163  std::cerr << "ERROR: Failed to parse protobuf from file: "
164  << absl::GetFlag(FLAGS_request) << std::endl;
165  return 1;
166  }
167  if (!absl::GetFlag(FLAGS_model_name).empty()) {
168  req.mutable_model_spec()->set_name(absl::GetFlag(FLAGS_model_name));
169  }
170  if (absl::GetFlag(FLAGS_model_version) >= 0) {
171  req.mutable_model_spec()->mutable_version()->set_value(
172  absl::GetFlag(FLAGS_model_version));
173  }
174 
175  ServingClient client(absl::GetFlag(FLAGS_server_port));
176  std::cout << "Sending " << absl::GetFlag(FLAGS_num_requests)
177  << " requests to " << absl::GetFlag(FLAGS_server_port) << " at "
178  << absl::GetFlag(FLAGS_qps) << " requests/sec." << std::endl;
179  for (int i = 0; i < absl::GetFlag(FLAGS_num_requests); i++) {
180  client.IssuePredict(req);
181  Env::Default()->SleepForMicroseconds(1000000 / absl::GetFlag(FLAGS_qps));
182  }
183 
184  std::cout << "Waiting for " << absl::GetFlag(FLAGS_num_requests)
185  << " requests to complete..." << std::endl;
186  client.WaitForCompletion(absl::GetFlag(FLAGS_num_requests));
187  client.DumpStats();
188  return 0;
189 }