15 #include "tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/prediction_service_grpc.h"
20 #include "grpcpp/create_channel.h"
21 #include "grpcpp/security/credentials.h"
22 #include "absl/time/clock.h"
24 using namespace tensorflow;
25 namespace tensorflow {
// Translates a ::grpc::Status into an absl::Status.
//
// Two outcomes are visible here: a default-constructed absl::Status, or an
// absl::Status built from the gRPC error code of `s`.
//
// NOTE(review): the static_cast below is only meaningful if ::grpc::StatusCode
// and absl::StatusCode use the same numeric values — confirm against both
// enum definitions.
29 absl::Status FromGrpcStatus(const ::grpc::Status& s) {
// NOTE(review): the condition guarding this early return is not visible in
// this excerpt; presumably it is taken when `s` is OK — confirm.
31 return absl::Status();
// Otherwise the numeric gRPC error code is reinterpreted as an
// absl::StatusCode and used as the first constructor argument.
// NOTE(review): the remaining constructor argument(s) are not visible in this
// excerpt — presumably the gRPC error message; confirm.
33 return absl::Status(
static_cast<absl::StatusCode
>(s.error_code()),
39 PredictionServiceGrpc::PredictionServiceGrpc(
40 const std::string& target_address) {
42 auto channel = ::grpc::CreateChannel(target_address,
43 ::grpc::InsecureChannelCredentials());
44 stub_ = tensorflow::serving::PredictionService::NewStub(channel);
47 absl::StatusOr< ::grpc::ClientContext*> PredictionServiceGrpc::CreateRpc(
48 absl::Duration max_rpc_deadline) {
49 ::grpc::ClientContext* rpc = new ::grpc::ClientContext();
52 rpc->set_deadline(std::chrono::system_clock::now() +
53 absl::ToChronoSeconds(max_rpc_deadline));
57 void PredictionServiceGrpc::Predict(
58 ::grpc::ClientContext* rpc, PredictRequest* request,
59 PredictResponse* response,
60 std::function<
void(absl::Status status)> callback) {
61 std::function<void(::grpc::Status)> wrapped_callback =
62 [callback](::grpc::Status status) { callback(FromGrpcStatus(status)); };
64 stub_->experimental_async()->Predict(rpc, request, response,