16 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_warmup.h"
20 #include "google/protobuf/wrappers.pb.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/status/status.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/cc/saved_model/constants.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/lib/io/record_reader.h"
29 #include "tensorflow/core/lib/monitoring/sampler.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/protobuf/config.pb.h"
32 #include "tensorflow_serving/apis/classification.pb.h"
33 #include "tensorflow_serving/apis/inference.pb.h"
34 #include "tensorflow_serving/apis/predict.pb.h"
35 #include "tensorflow_serving/apis/prediction_log.pb.h"
36 #include "tensorflow_serving/apis/regression.pb.h"
37 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
38 #include "tensorflow_serving/servables/tensorflow/tfrt_multi_inference.h"
39 #include "tensorflow_serving/servables/tensorflow/tfrt_predict_util.h"
40 #include "tensorflow_serving/servables/tensorflow/tfrt_regressor.h"
41 #include "tensorflow_serving/servables/tensorflow/util.h"
43 namespace tensorflow {
47 Status RunWarmupRequest(
const PredictionLog& warmup_record,
48 const tfrt::SavedModel::RunOptions& run_options,
49 int lazy_init_threshold,
50 bool skip_warmup_requests_if_initialized,
51 tfrt::SavedModel* saved_model) {
57 if (skip_warmup_requests_if_initialized &&
58 saved_model->GetMetaGraphDef().signature_def_size() <=
59 lazy_init_threshold &&
60 warmup_record.log_type_case() != PredictionLog::kMultiInferenceLog) {
61 return absl::OkStatus();
64 switch (warmup_record.log_type_case()) {
65 case PredictionLog::kPredictLog: {
66 PredictResponse response;
67 TF_RETURN_IF_ERROR(RunPredict(run_options, {}, saved_model,
68 warmup_record.predict_log().request(),
71 case PredictionLog::kPredictStreamedLog: {
72 if (warmup_record.predict_streamed_log().request_size() == 0) {
73 return absl::InvalidArgumentError(absl::StrCat(
74 "predict_streamed_log does not contain any requests."));
76 if (warmup_record.predict_streamed_log().request_size() > 1) {
77 return absl::InvalidArgumentError(
78 absl::StrCat(
"predict_streamed_log contains more than one request, "
79 "which is not supported by PredictStreamed."));
81 PredictResponse response;
82 auto run_opts = run_options;
83 run_opts.streamed_output_callback =
84 [](absl::flat_hash_map<std::string, tensorflow::Tensor>) {};
85 TF_RETURN_IF_ERROR(RunPredict(
86 run_opts, {}, saved_model,
87 warmup_record.predict_streamed_log().request(0), &response));
89 case PredictionLog::kClassifyLog: {
90 ClassificationResponse response;
91 TF_RETURN_IF_ERROR(RunClassify(run_options, {}, saved_model,
92 warmup_record.classify_log().request(),
96 case PredictionLog::kRegressLog: {
97 RegressionResponse response;
98 TF_RETURN_IF_ERROR(RunRegress(run_options, {}, saved_model,
99 warmup_record.regress_log().request(),
103 case PredictionLog::kMultiInferenceLog: {
104 MultiInferenceResponse response;
105 TF_RETURN_IF_ERROR(RunMultiInference(
106 run_options, {}, saved_model,
107 warmup_record.multi_inference_log().request(), &response));
111 return errors::Unimplemented(strings::StrCat(
112 "Unsupported log_type for warmup: ", warmup_record.log_type_case()));
115 return absl::OkStatus();
120 Status RunSavedModelWarmup(
const ModelWarmupOptions& model_warmup_options,
121 const string& export_dir,
int lazy_init_threshold,
122 bool skip_warmup_requests_if_initialized,
123 tfrt::SavedModel* saved_model) {
124 tfrt::SavedModel::RunOptions run_options;
125 return internal::RunSavedModelWarmup(
126 model_warmup_options, export_dir, [&](PredictionLog prediction_log) {
127 return RunWarmupRequest(
128 prediction_log, run_options, lazy_init_threshold,
129 skip_warmup_requests_if_initialized, saved_model);