TensorFlow Serving C++ API Documentation
saved_model_warmup_test_util.cc
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_test_util.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/cc/saved_model/constants.h"
22 #include "tensorflow/cc/saved_model/signature_constants.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow_serving/apis/prediction_log.pb.h"
28 
29 namespace tensorflow {
30 namespace serving {
31 
32 void PopulateInferenceTask(const string& model_name,
33  const string& signature_name,
34  const string& method_name, InferenceTask* task) {
35  ModelSpec model_spec;
36  model_spec.set_name(model_name);
37  model_spec.set_signature_name(signature_name);
38  *task->mutable_model_spec() = model_spec;
39  task->set_method_name(method_name);
40 }
41 
42 void PopulateMultiInferenceRequest(MultiInferenceRequest* request) {
43  request->mutable_input()->mutable_example_list()->add_examples();
44  PopulateInferenceTask("test_model", kRegressMethodName, kRegressMethodName,
45  request->add_tasks());
46  PopulateInferenceTask("test_model", kClassifyMethodName, kClassifyMethodName,
47  request->add_tasks());
48 }
49 
50 void PopulatePredictRequest(PredictRequest* request) {
51  request->mutable_model_spec()->set_signature_name(kPredictMethodName);
52  TensorProto tensor_proto;
53  tensor_proto.add_string_val("input_value");
54  tensor_proto.set_dtype(tensorflow::DT_STRING);
55  tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
56  (*request->mutable_inputs())[kPredictInputs] = tensor_proto;
57 }
58 
59 void PopulateClassificationRequest(ClassificationRequest* request) {
60  request->mutable_input()->mutable_example_list()->add_examples();
61  request->mutable_model_spec()->set_signature_name(kClassifyMethodName);
62 }
63 
64 void PopulateRegressionRequest(RegressionRequest* request) {
65  request->mutable_input()->mutable_example_list()->add_examples();
66  request->mutable_model_spec()->set_signature_name(kRegressMethodName);
67 }
68 
69 Status PopulatePredictionLog(PredictionLog* prediction_log,
70  PredictionLog::LogTypeCase log_type,
71  int num_repeated_values) {
72  if ((num_repeated_values > 1) &&
73  (log_type != PredictionLog::kPredictStreamedLog)) {
74  return errors::InvalidArgument(
75  "Only predict_streamed_log supports num_repeated_values > 1.");
76  }
77  switch (log_type) {
78  case PredictionLog::kRegressLog: {
79  PopulateRegressionRequest(
80  prediction_log->mutable_regress_log()->mutable_request());
81  } break;
82  case PredictionLog::kClassifyLog: {
83  PopulateClassificationRequest(
84  prediction_log->mutable_classify_log()->mutable_request());
85  } break;
86  case PredictionLog::kPredictLog: {
87  PopulatePredictRequest(
88  prediction_log->mutable_predict_log()->mutable_request());
89  } break;
90  case PredictionLog::kPredictStreamedLog: {
91  for (int i = 0; i < num_repeated_values; ++i) {
92  PopulatePredictRequest(
93  prediction_log->mutable_predict_streamed_log()->add_request());
94  }
95  } break;
96  case PredictionLog::kMultiInferenceLog: {
97  PopulateMultiInferenceRequest(
98  prediction_log->mutable_multi_inference_log()->mutable_request());
99  } break;
100  case PredictionLog::kSessionRunLog:
101  prediction_log->mutable_session_run_log();
102  TF_FALLTHROUGH_INTENDED;
103  default:
104  break;
105  }
106  return absl::OkStatus();
107 }
108 
109 Status WriteWarmupData(const string& fname,
110  const std::vector<string>& warmup_records,
111  int num_warmup_records) {
112  Env* env = Env::Default();
113  std::unique_ptr<WritableFile> file;
114  TF_RETURN_IF_ERROR(env->NewWritableFile(fname, &file));
115 
116  io::RecordWriterOptions options;
117  io::RecordWriter writer(file.get(), options);
118  for (int i = 0; i < num_warmup_records; ++i) {
119  for (const string& warmup_record : warmup_records) {
120  TF_RETURN_IF_ERROR(writer.WriteRecord(warmup_record));
121  }
122  }
123  TF_RETURN_IF_ERROR(writer.Flush());
124  return absl::OkStatus();
125 }
126 
127 Status WriteWarmupDataAsSerializedProtos(
128  const string& fname, const std::vector<string>& warmup_records,
129  int num_warmup_records) {
130  Env* env = Env::Default();
131  std::unique_ptr<WritableFile> file;
132  TF_RETURN_IF_ERROR(env->NewWritableFile(fname, &file));
133  for (int i = 0; i < num_warmup_records; ++i) {
134  for (const string& warmup_record : warmup_records) {
135  TF_RETURN_IF_ERROR(file->Append(warmup_record));
136  }
137  }
138  TF_RETURN_IF_ERROR(file->Close());
139  return absl::OkStatus();
140 }
141 
142 Status AddMixedWarmupData(
143  std::vector<string>* warmup_records,
144  const std::vector<PredictionLog::LogTypeCase>& log_types) {
145  for (auto& log_type : log_types) {
146  TF_RETURN_IF_ERROR(AddToWarmupData(warmup_records, log_type, 1));
147  }
148  return absl::OkStatus();
149 }
150 
151 Status AddToWarmupData(std::vector<string>* warmup_records,
152  PredictionLog::LogTypeCase log_type,
153  int num_repeated_values) {
154  PredictionLog prediction_log;
155  TF_RETURN_IF_ERROR(
156  PopulatePredictionLog(&prediction_log, log_type, num_repeated_values));
157  warmup_records->push_back(prediction_log.SerializeAsString());
158  return absl::OkStatus();
159 }
160 
161 // Creates a test SignatureDef with the given parameters
162 SignatureDef CreateSignatureDef(const string& method_name,
163  const std::vector<string>& input_names,
164  const std::vector<string>& output_names) {
165  SignatureDef signature_def;
166  signature_def.set_method_name(method_name);
167  for (const string& input_name : input_names) {
168  TensorInfo input;
169  input.set_name(input_name);
170  (*signature_def.mutable_inputs())[input_name] = input;
171  }
172  for (const string& output_name : output_names) {
173  TensorInfo output;
174  output.set_name(output_name);
175  (*signature_def.mutable_outputs())[output_name] = output;
176  }
177  return signature_def;
178 }
179 
180 void AddSignatures(MetaGraphDef* meta_graph_def) {
181  (*meta_graph_def->mutable_signature_def())[kRegressMethodName] =
182  CreateSignatureDef(kRegressMethodName, {kRegressInputs},
183  {kRegressOutputs});
184  (*meta_graph_def->mutable_signature_def())[kClassifyMethodName] =
185  CreateSignatureDef(kClassifyMethodName, {kClassifyInputs},
186  {kClassifyOutputClasses, kClassifyOutputScores});
187  (*meta_graph_def->mutable_signature_def())[kPredictMethodName] =
188  CreateSignatureDef(kPredictMethodName, {kPredictInputs},
189  {kPredictOutputs});
190 }
191 
192 } // namespace serving
193 } // namespace tensorflow