#include <csetjmp>
#include <cstdio>

#include <iostream>
#include <memory>
#include <vector>

#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "google/protobuf/map.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/jpeg.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
33 using grpc::ClientContext;
36 using tensorflow::serving::PredictRequest;
37 using tensorflow::serving::PredictResponse;
38 using tensorflow::serving::PredictionService;
40 typedef google::protobuf::Map<tensorflow::string, tensorflow::TensorProto> OutMap;
43 struct jpeg_error_mgr pub;
44 jmp_buf setjmp_buffer;
50 tf_jpeg_error_exit(j_common_ptr cinfo) {
53 (*cinfo->err->output_message)(cinfo);
55 longjmp(tf_jpeg_err->setjmp_buffer, 1);
62 int readJPEG(
const char* file_name, tensorflow::TensorProto* proto) {
67 struct jpeg_decompress_struct cinfo;
69 if ((infile = fopen(file_name,
"rb")) == NULL) {
70 fprintf(stderr,
"can't open %s\n", file_name);
74 cinfo.err = jpeg_std_error(&jerr.pub);
75 jerr.pub.error_exit = tf_jpeg_error_exit;
76 if (setjmp(jerr.setjmp_buffer)) {
77 jpeg_destroy_decompress(&cinfo);
82 jpeg_create_decompress(&cinfo);
83 jpeg_stdio_src(&cinfo, infile);
85 (void)jpeg_read_header(&cinfo, TRUE);
87 (void)jpeg_start_decompress(&cinfo);
88 row_stride = cinfo.output_width * cinfo.output_components;
89 CHECK(cinfo.output_components == 3)
90 <<
"Only 3-channel (RGB) JPEG files are supported";
92 buffer = (*cinfo.mem->alloc_sarray)((j_common_ptr)&cinfo, JPOOL_IMAGE,
95 proto->set_dtype(tensorflow::DataType::DT_FLOAT);
96 while (cinfo.output_scanline < cinfo.output_height) {
97 (void)jpeg_read_scanlines(&cinfo, buffer, 1);
98 for (
size_t i = 0; i < cinfo.output_width; i++) {
99 proto->add_float_val(buffer[0][i * 3] / 255.0);
100 proto->add_float_val(buffer[0][i * 3 + 1] / 255.0);
101 proto->add_float_val(buffer[0][i * 3 + 2] / 255.0);
105 proto->mutable_tensor_shape()->add_dim()->set_size(1);
106 proto->mutable_tensor_shape()->add_dim()->set_size(cinfo.output_height);
107 proto->mutable_tensor_shape()->add_dim()->set_size(cinfo.output_width);
108 proto->mutable_tensor_shape()->add_dim()->set_size(cinfo.output_components);
110 (void)jpeg_finish_decompress(&cinfo);
112 jpeg_destroy_decompress(&cinfo);
118 : stub_(PredictionService::NewStub(channel)) {}
120 tensorflow::string callPredict(
const tensorflow::string& model_name,
121 const tensorflow::string& model_signature_name,
122 const tensorflow::string& input_name,
123 const tensorflow::string& file_path) {
124 PredictRequest predictRequest;
125 PredictResponse response;
126 ClientContext context;
128 predictRequest.mutable_model_spec()->set_name(model_name);
129 predictRequest.mutable_model_spec()->set_signature_name(
130 model_signature_name);
132 google::protobuf::Map<tensorflow::string, tensorflow::TensorProto>& inputs =
133 *predictRequest.mutable_inputs();
135 tensorflow::TensorProto proto;
137 const char* infile = file_path.c_str();
139 if (readJPEG(infile, &proto)) {
140 std::cout <<
"error constructing the protobuf";
141 return "execution failed";
144 inputs[input_name] = proto;
146 Status status = stub_->Predict(&context, predictRequest, &response);
149 std::cout <<
"call predict ok" << std::endl;
150 std::cout <<
"outputs size is " << response.outputs_size() << std::endl;
151 OutMap& map_outputs = *response.mutable_outputs();
152 OutMap::iterator iter;
153 int output_index = 0;
155 for (iter = map_outputs.begin(); iter != map_outputs.end(); ++iter) {
156 tensorflow::TensorProto& result_tensor_proto = iter->second;
157 tensorflow::Tensor tensor;
158 bool converted = tensor.FromProto(result_tensor_proto);
160 std::cout <<
"the result tensor[" << output_index
161 <<
"] is:" << std::endl
162 << tensor.SummarizeValue(1001) << std::endl;
164 std::cout <<
"the result tensor[" << output_index
165 <<
"] convert failed." << std::endl;
171 std::cout <<
"gRPC call return code: " << status.error_code() <<
": "
172 << status.error_message() << std::endl;
173 return "gRPC failed.";
178 std::unique_ptr<PredictionService::Stub> stub_;
181 int main(
int argc,
char** argv) {
182 tensorflow::string server_port =
"localhost:8500";
183 tensorflow::string image_file =
"";
184 tensorflow::string model_name =
"resnet";
185 tensorflow::string model_signature_name =
"serving_default";
186 tensorflow::string input_name =
"input_1";
187 std::vector<tensorflow::Flag> flag_list = {
188 tensorflow::Flag(
"server_port", &server_port,
189 "the IP and port of the server"),
190 tensorflow::Flag(
"image_file", &image_file,
"the path to the image"),
191 tensorflow::Flag(
"model_name", &model_name,
"name of model"),
192 tensorflow::Flag(
"model_signature_name", &model_signature_name,
193 "name of model signature"),
194 tensorflow::Flag(
"input_name", &input_name,
"name of input tensor")};
196 tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
197 const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
198 if (!parse_result || image_file.empty()) {
204 grpc::CreateChannel(server_port, grpc::InsecureChannelCredentials()));
205 std::cout <<
"calling predict using file: " << image_file <<
" ..."
207 std::cout << guide.callPredict(model_name, model_signature_name, input_name,