24 #include "grpcpp/create_channel.h"
25 #include "grpcpp/security/credentials.h"
26 #include "absl/flags/flag.h"
27 #include "absl/flags/parse.h"
28 #include "absl/strings/match.h"
29 #include "xla/tsl/lib/histogram/histogram.h"
30 #include "tensorflow/core/platform/env.h"
31 #include "tensorflow/core/platform/env_time.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow_serving/apis/prediction_service.grpc.pb.h"

using grpc::ClientAsyncResponseReader;
using grpc::ClientContext;
using grpc::CompletionQueue;
using tensorflow::Env;
using tensorflow::EnvTime;
using tensorflow::Thread;
using tensorflow::protobuf::TextFormat;
using tensorflow::serving::PredictRequest;
using tensorflow::serving::PredictResponse;
using tensorflow::serving::PredictionService;
using tsl::histogram::ThreadSafeHistogram;

ABSL_FLAG(std::string, server_port, "", "Target server (host:port)");
ABSL_FLAG(std::string, request, "", "File containing request proto message");
ABSL_FLAG(std::string, model_name, "",
          "Model name to override in the request");
ABSL_FLAG(int, model_version, -1,
          "Model version to override in the request");
ABSL_FLAG(int, num_requests, 1, "Total number of requests to send.");
ABSL_FLAG(int, qps, 1, "Rate for sending requests.");
ABSL_FLAG(int, rpc_deadline_ms, 1000, "RPC request deadline in milliseconds.");
ABSL_FLAG(bool, print_rpc_errors, false, "Print RPC errors.");
bool ReadProtoFromFile(const std::string& filename, PredictRequest* req) {
  auto in = std::ifstream(filename);
  if (!in) return false;
  std::ostringstream ss;
  ss << in.rdbuf();
  return absl::EndsWith(filename, ".pbtxt")
             ? TextFormat::ParseFromString(ss.str(), req)
             : req->ParseFromString(ss.str());
}

// Asynchronous Predict benchmark client; collects latency and error
// histograms across RPCs. The class name `ServingClient` is an assumed
// placeholder, and the counter members are zero-initialized here.
class ServingClient {
 public:
  explicit ServingClient(const std::string& server_port)
      : stub_(PredictionService::NewStub(grpc::CreateChannel(
            server_port, grpc::InsecureChannelCredentials()))),
        done_count_(0),
        success_count_(0),
        error_count_(0),
        latency_histogram_(new ThreadSafeHistogram()),
        error_histogram_(new ThreadSafeHistogram(
            {0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})) {
    thread_.reset(Env::Default()->StartThread(
        {}, "reaprpcs", [this]() { ReapCompletions(); }));
  }
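
  // Fires a single asynchronous Predict RPC; its completion is handled on the
  // reaper thread.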
  void IssuePredict(const PredictRequest& req) {
    auto state = new RpcState();
    state->context.set_deadline(
        std::chrono::system_clock::now() +
        std::chrono::milliseconds(absl::GetFlag(FLAGS_rpc_deadline_ms)));
    state->start_time = EnvTime::NowMicros();
    std::unique_ptr<ClientAsyncResponseReader<PredictResponse>> rpc(
        stub_->AsyncPredict(&state->context, req, &cq_));
    // The tag passed to Finish() is reclaimed by ReapCompletions().
    rpc->Finish(&state->resp, &state->status, (void*)state);
  }
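
  // Drains the completion queue, recording latency for successful RPCs and
  // the status code for failed ones.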
  void ReapCompletions() {
    void* sp;
    bool ok;
    while (cq_.Next(&sp, &ok)) {
      done_count_++;
      std::unique_ptr<RpcState> state((RpcState*)sp);
      if (state->status.ok()) {
        success_count_++;
        latency_histogram_->Add(EnvTime::NowMicros() - state->start_time);
      } else {
        error_count_++;
        error_histogram_->Add(state->status.error_code());
        if (absl::GetFlag(FLAGS_print_rpc_errors)) {
          std::cerr << "ERROR: RPC failed code: "
                    << state->status.error_code() << " msg: "
                    << state->status.error_message() << std::endl;
        }
      }
    }
  }
  void WaitForCompletion(int total_rpcs) {
    while (done_count_ < total_rpcs) {
      Env::Default()->SleepForMicroseconds(1000);
    }
    // Assumed cleanup: shutting down the queue makes cq_.Next() return false,
    // so ReapCompletions() exits and the reaper thread can be joined.
    cq_.Shutdown();
    thread_.reset();

    if (success_count_) {
      std::cout << "Request stats (successful)" << std::endl;
      std::cout << latency_histogram_->ToString() << std::endl;
    }
    if (error_count_) {
      std::cout << "Request stats (errors)" << std::endl;
      std::cout << error_histogram_->ToString() << std::endl;
    }
  }

 private:
  // Per-RPC state owned by the completion-queue tag; the start_time and
  // status member declarations are inferred from their use above.
  struct RpcState {
    uint64_t start_time;
    ClientContext context;
    PredictResponse resp;
    grpc::Status status;
  };
  std::unique_ptr<PredictionService::Stub> stub_;
  CompletionQueue cq_;
  std::unique_ptr<Thread> thread_;
  std::atomic<int> done_count_;
  std::atomic<int> success_count_;
  std::atomic<int> error_count_;
  std::unique_ptr<ThreadSafeHistogram> latency_histogram_;
  std::unique_ptr<ThreadSafeHistogram> error_histogram_;
};
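
// Entry point: parses flags, loads the request proto, applies any model
// name/version overrides, then issues requests at the configured QPS. The
// error-exit return values and the `ServingClient` construction below are
// assumptions. Illustrative invocation (binary and file names are
// placeholders; the flags are defined above):
//
//   benchmark_client --server_port=localhost:8500 \
//       --request=predict_request.pbtxt --num_requests=1000 --qps=100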
int main(int argc, char** argv) {
  absl::ParseCommandLine(argc, argv);

  if (absl::GetFlag(FLAGS_server_port).empty() ||
      absl::GetFlag(FLAGS_request).empty()) {
    std::cerr << "ERROR: --server_port and --request flags are required."
              << std::endl;
    return -1;
  }

  PredictRequest req;
  if (!ReadProtoFromFile(absl::GetFlag(FLAGS_request), &req)) {
    std::cerr << "ERROR: Failed to parse protobuf from file: "
              << absl::GetFlag(FLAGS_request) << std::endl;
    return -1;
  }
  if (!absl::GetFlag(FLAGS_model_name).empty()) {
    req.mutable_model_spec()->set_name(absl::GetFlag(FLAGS_model_name));
  }
  if (absl::GetFlag(FLAGS_model_version) >= 0) {
    req.mutable_model_spec()->mutable_version()->set_value(
        absl::GetFlag(FLAGS_model_version));
  }

  ServingClient client(absl::GetFlag(FLAGS_server_port));
  std::cout << "Sending " << absl::GetFlag(FLAGS_num_requests)
            << " requests to " << absl::GetFlag(FLAGS_server_port) << " at "
            << absl::GetFlag(FLAGS_qps) << " requests/sec." << std::endl;
  for (int i = 0; i < absl::GetFlag(FLAGS_num_requests); i++) {
    client.IssuePredict(req);
    // Naive rate limiting: sleep one inter-arrival period per request.
    Env::Default()->SleepForMicroseconds(1000000 / absl::GetFlag(FLAGS_qps));
  }

  std::cout << "Waiting for " << absl::GetFlag(FLAGS_num_requests)
            << " requests to complete..." << std::endl;
  client.WaitForCompletion(absl::GetFlag(FLAGS_num_requests));
  return 0;
}