16 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_test_util.h"
21 #include "tensorflow/cc/saved_model/constants.h"
22 #include "tensorflow/cc/saved_model/signature_constants.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow_serving/apis/prediction_log.pb.h"
29 namespace tensorflow {
32 void PopulateInferenceTask(
const string& model_name,
33 const string& signature_name,
34 const string& method_name, InferenceTask* task) {
36 model_spec.set_name(model_name);
37 model_spec.set_signature_name(signature_name);
38 *task->mutable_model_spec() = model_spec;
39 task->set_method_name(method_name);
42 void PopulateMultiInferenceRequest(MultiInferenceRequest* request) {
43 request->mutable_input()->mutable_example_list()->add_examples();
44 PopulateInferenceTask(
"test_model", kRegressMethodName, kRegressMethodName,
45 request->add_tasks());
46 PopulateInferenceTask(
"test_model", kClassifyMethodName, kClassifyMethodName,
47 request->add_tasks());
50 void PopulatePredictRequest(PredictRequest* request) {
51 request->mutable_model_spec()->set_signature_name(kPredictMethodName);
52 TensorProto tensor_proto;
53 tensor_proto.add_string_val(
"input_value");
54 tensor_proto.set_dtype(tensorflow::DT_STRING);
55 tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
56 (*request->mutable_inputs())[kPredictInputs] = tensor_proto;
59 void PopulateClassificationRequest(ClassificationRequest* request) {
60 request->mutable_input()->mutable_example_list()->add_examples();
61 request->mutable_model_spec()->set_signature_name(kClassifyMethodName);
64 void PopulateRegressionRequest(RegressionRequest* request) {
65 request->mutable_input()->mutable_example_list()->add_examples();
66 request->mutable_model_spec()->set_signature_name(kRegressMethodName);
69 Status PopulatePredictionLog(PredictionLog* prediction_log,
70 PredictionLog::LogTypeCase log_type,
71 int num_repeated_values) {
72 if ((num_repeated_values > 1) &&
73 (log_type != PredictionLog::kPredictStreamedLog)) {
74 return errors::InvalidArgument(
75 "Only predict_streamed_log supports num_repeated_values > 1.");
78 case PredictionLog::kRegressLog: {
79 PopulateRegressionRequest(
80 prediction_log->mutable_regress_log()->mutable_request());
82 case PredictionLog::kClassifyLog: {
83 PopulateClassificationRequest(
84 prediction_log->mutable_classify_log()->mutable_request());
86 case PredictionLog::kPredictLog: {
87 PopulatePredictRequest(
88 prediction_log->mutable_predict_log()->mutable_request());
90 case PredictionLog::kPredictStreamedLog: {
91 for (
int i = 0; i < num_repeated_values; ++i) {
92 PopulatePredictRequest(
93 prediction_log->mutable_predict_streamed_log()->add_request());
96 case PredictionLog::kMultiInferenceLog: {
97 PopulateMultiInferenceRequest(
98 prediction_log->mutable_multi_inference_log()->mutable_request());
100 case PredictionLog::kSessionRunLog:
101 prediction_log->mutable_session_run_log();
102 TF_FALLTHROUGH_INTENDED;
106 return absl::OkStatus();
109 Status WriteWarmupData(
const string& fname,
110 const std::vector<string>& warmup_records,
111 int num_warmup_records) {
112 Env* env = Env::Default();
113 std::unique_ptr<WritableFile> file;
114 TF_RETURN_IF_ERROR(env->NewWritableFile(fname, &file));
116 io::RecordWriterOptions options;
117 io::RecordWriter writer(file.get(), options);
118 for (
int i = 0; i < num_warmup_records; ++i) {
119 for (
const string& warmup_record : warmup_records) {
120 TF_RETURN_IF_ERROR(writer.WriteRecord(warmup_record));
123 TF_RETURN_IF_ERROR(writer.Flush());
124 return absl::OkStatus();
127 Status WriteWarmupDataAsSerializedProtos(
128 const string& fname,
const std::vector<string>& warmup_records,
129 int num_warmup_records) {
130 Env* env = Env::Default();
131 std::unique_ptr<WritableFile> file;
132 TF_RETURN_IF_ERROR(env->NewWritableFile(fname, &file));
133 for (
int i = 0; i < num_warmup_records; ++i) {
134 for (
const string& warmup_record : warmup_records) {
135 TF_RETURN_IF_ERROR(file->Append(warmup_record));
138 TF_RETURN_IF_ERROR(file->Close());
139 return absl::OkStatus();
142 Status AddMixedWarmupData(
143 std::vector<string>* warmup_records,
144 const std::vector<PredictionLog::LogTypeCase>& log_types) {
145 for (
auto& log_type : log_types) {
146 TF_RETURN_IF_ERROR(AddToWarmupData(warmup_records, log_type, 1));
148 return absl::OkStatus();
151 Status AddToWarmupData(std::vector<string>* warmup_records,
152 PredictionLog::LogTypeCase log_type,
153 int num_repeated_values) {
154 PredictionLog prediction_log;
156 PopulatePredictionLog(&prediction_log, log_type, num_repeated_values));
157 warmup_records->push_back(prediction_log.SerializeAsString());
158 return absl::OkStatus();
162 SignatureDef CreateSignatureDef(
const string& method_name,
163 const std::vector<string>& input_names,
164 const std::vector<string>& output_names) {
165 SignatureDef signature_def;
166 signature_def.set_method_name(method_name);
167 for (
const string& input_name : input_names) {
169 input.set_name(input_name);
170 (*signature_def.mutable_inputs())[input_name] = input;
172 for (
const string& output_name : output_names) {
174 output.set_name(output_name);
175 (*signature_def.mutable_outputs())[output_name] = output;
177 return signature_def;
180 void AddSignatures(MetaGraphDef* meta_graph_def) {
181 (*meta_graph_def->mutable_signature_def())[kRegressMethodName] =
182 CreateSignatureDef(kRegressMethodName, {kRegressInputs},
184 (*meta_graph_def->mutable_signature_def())[kClassifyMethodName] =
185 CreateSignatureDef(kClassifyMethodName, {kClassifyInputs},
186 {kClassifyOutputClasses, kClassifyOutputScores});
187 (*meta_graph_def->mutable_signature_def())[kPredictMethodName] =
188 CreateSignatureDef(kPredictMethodName, {kPredictInputs},