// TensorFlow Serving C++ example client (resnet_client.cc).
1 /* Copyright 2017 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <setjmp.h>
17 
18 #include <fstream>
19 #include <iostream>
20 #include <memory>
21 #include <vector>
22 
23 #include "grpcpp/create_channel.h"
24 #include "grpcpp/security/credentials.h"
25 #include "google/protobuf/map.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/platform/jpeg.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/core/util/command_line_flags.h"
30 #include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
31 
32 using grpc::Channel;
33 using grpc::ClientContext;
34 using grpc::Status;
35 
36 using tensorflow::serving::PredictRequest;
37 using tensorflow::serving::PredictResponse;
38 using tensorflow::serving::PredictionService;
39 
40 typedef google::protobuf::Map<tensorflow::string, tensorflow::TensorProto> OutMap;
41 
43  struct jpeg_error_mgr pub;
44  jmp_buf setjmp_buffer;
45 };
46 
47 typedef struct tf_jpeg_error_mgr* tf_jpeg_error_ptr;
48 
49 METHODDEF(void)
50 tf_jpeg_error_exit(j_common_ptr cinfo) {
51  tf_jpeg_error_ptr tf_jpeg_err = (tf_jpeg_error_ptr)cinfo->err;
52 
53  (*cinfo->err->output_message)(cinfo);
54 
55  longjmp(tf_jpeg_err->setjmp_buffer, 1);
56 }
57 
59  public:
60  // JPEG decompression code following libjpeg-turbo documentation:
61  // https://github.com/libjpeg-turbo/libjpeg-turbo/blob/main/example.txt
62  int readJPEG(const char* file_name, tensorflow::TensorProto* proto) {
63  struct tf_jpeg_error_mgr jerr;
64  FILE* infile;
65  JSAMPARRAY buffer;
66  int row_stride;
67  struct jpeg_decompress_struct cinfo;
68 
69  if ((infile = fopen(file_name, "rb")) == NULL) {
70  fprintf(stderr, "can't open %s\n", file_name);
71  return -1;
72  }
73 
74  cinfo.err = jpeg_std_error(&jerr.pub);
75  jerr.pub.error_exit = tf_jpeg_error_exit;
76  if (setjmp(jerr.setjmp_buffer)) {
77  jpeg_destroy_decompress(&cinfo);
78  fclose(infile);
79  return -1;
80  }
81 
82  jpeg_create_decompress(&cinfo);
83  jpeg_stdio_src(&cinfo, infile);
84 
85  (void)jpeg_read_header(&cinfo, TRUE);
86 
87  (void)jpeg_start_decompress(&cinfo);
88  row_stride = cinfo.output_width * cinfo.output_components;
89  CHECK(cinfo.output_components == 3)
90  << "Only 3-channel (RGB) JPEG files are supported";
91 
92  buffer = (*cinfo.mem->alloc_sarray)((j_common_ptr)&cinfo, JPOOL_IMAGE,
93  row_stride, 1);
94 
95  proto->set_dtype(tensorflow::DataType::DT_FLOAT);
96  while (cinfo.output_scanline < cinfo.output_height) {
97  (void)jpeg_read_scanlines(&cinfo, buffer, 1);
98  for (size_t i = 0; i < cinfo.output_width; i++) {
99  proto->add_float_val(buffer[0][i * 3] / 255.0);
100  proto->add_float_val(buffer[0][i * 3 + 1] / 255.0);
101  proto->add_float_val(buffer[0][i * 3 + 2] / 255.0);
102  }
103  }
104 
105  proto->mutable_tensor_shape()->add_dim()->set_size(1);
106  proto->mutable_tensor_shape()->add_dim()->set_size(cinfo.output_height);
107  proto->mutable_tensor_shape()->add_dim()->set_size(cinfo.output_width);
108  proto->mutable_tensor_shape()->add_dim()->set_size(cinfo.output_components);
109 
110  (void)jpeg_finish_decompress(&cinfo);
111 
112  jpeg_destroy_decompress(&cinfo);
113  fclose(infile);
114  return 0;
115  }
116 
  // Builds a client whose Predict RPCs travel over `channel`; owns the stub.
  ServingClient(std::shared_ptr<Channel> channel)
      : stub_(PredictionService::NewStub(channel)) {}
119 
120  tensorflow::string callPredict(const tensorflow::string& model_name,
121  const tensorflow::string& model_signature_name,
122  const tensorflow::string& input_name,
123  const tensorflow::string& file_path) {
124  PredictRequest predictRequest;
125  PredictResponse response;
126  ClientContext context;
127 
128  predictRequest.mutable_model_spec()->set_name(model_name);
129  predictRequest.mutable_model_spec()->set_signature_name(
130  model_signature_name);
131 
132  google::protobuf::Map<tensorflow::string, tensorflow::TensorProto>& inputs =
133  *predictRequest.mutable_inputs();
134 
135  tensorflow::TensorProto proto;
136 
137  const char* infile = file_path.c_str();
138 
139  if (readJPEG(infile, &proto)) {
140  std::cout << "error constructing the protobuf";
141  return "execution failed";
142  }
143 
144  inputs[input_name] = proto;
145 
146  Status status = stub_->Predict(&context, predictRequest, &response);
147 
148  if (status.ok()) {
149  std::cout << "call predict ok" << std::endl;
150  std::cout << "outputs size is " << response.outputs_size() << std::endl;
151  OutMap& map_outputs = *response.mutable_outputs();
152  OutMap::iterator iter;
153  int output_index = 0;
154 
155  for (iter = map_outputs.begin(); iter != map_outputs.end(); ++iter) {
156  tensorflow::TensorProto& result_tensor_proto = iter->second;
157  tensorflow::Tensor tensor;
158  bool converted = tensor.FromProto(result_tensor_proto);
159  if (converted) {
160  std::cout << "the result tensor[" << output_index
161  << "] is:" << std::endl
162  << tensor.SummarizeValue(1001) << std::endl;
163  } else {
164  std::cout << "the result tensor[" << output_index
165  << "] convert failed." << std::endl;
166  }
167  ++output_index;
168  }
169  return "Done.";
170  } else {
171  std::cout << "gRPC call return code: " << status.error_code() << ": "
172  << status.error_message() << std::endl;
173  return "gRPC failed.";
174  }
175  }
176 
 private:
  // Owned gRPC stub for PredictionService; created once in the constructor.
  std::unique_ptr<PredictionService::Stub> stub_;
179 };
180 
181 int main(int argc, char** argv) {
182  tensorflow::string server_port = "localhost:8500";
183  tensorflow::string image_file = "";
184  tensorflow::string model_name = "resnet";
185  tensorflow::string model_signature_name = "serving_default";
186  tensorflow::string input_name = "input_1";
187  std::vector<tensorflow::Flag> flag_list = {
188  tensorflow::Flag("server_port", &server_port,
189  "the IP and port of the server"),
190  tensorflow::Flag("image_file", &image_file, "the path to the image"),
191  tensorflow::Flag("model_name", &model_name, "name of model"),
192  tensorflow::Flag("model_signature_name", &model_signature_name,
193  "name of model signature"),
194  tensorflow::Flag("input_name", &input_name, "name of input tensor")};
195 
196  tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
197  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
198  if (!parse_result || image_file.empty()) {
199  std::cout << usage;
200  return -1;
201  }
202 
203  ServingClient guide(
204  grpc::CreateChannel(server_port, grpc::InsecureChannelCredentials()));
205  std::cout << "calling predict using file: " << image_file << " ..."
206  << std::endl;
207  std::cout << guide.callPredict(model_name, model_signature_name, input_name,
208  image_file)
209  << std::endl;
210  return 0;
211 }