16 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup.h"
18 #include "google/protobuf/wrappers.pb.h"
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/cc/saved_model/constants.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/io/path.h"
23 #include "tensorflow/core/lib/io/record_reader.h"
24 #include "tensorflow/core/lib/monitoring/sampler.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow_serving/apis/prediction_log.pb.h"
28 #include "tensorflow_serving/servables/tensorflow/classifier.h"
29 #include "tensorflow_serving/servables/tensorflow/multi_inference.h"
30 #include "tensorflow_serving/servables/tensorflow/predict_util.h"
31 #include "tensorflow_serving/servables/tensorflow/regressor.h"
32 #include "tensorflow_serving/servables/tensorflow/util.h"
34 namespace tensorflow {
39 Status RunWarmupRequest(
const PredictionLog& warmup_record,
40 const RunOptions& run_options,
41 const MetaGraphDef& meta_graph_def, Session* session) {
42 switch (warmup_record.log_type_case()) {
43 case PredictionLog::kRegressLog: {
44 RegressionResponse response;
45 TF_RETURN_IF_ERROR(RunRegress(run_options, meta_graph_def, {}, session,
46 warmup_record.regress_log().request(),
49 case PredictionLog::kClassifyLog: {
50 ClassificationResponse response;
51 TF_RETURN_IF_ERROR(RunClassify(run_options, meta_graph_def, {}, session,
52 warmup_record.classify_log().request(),
55 case PredictionLog::kPredictLog: {
56 PredictResponse response;
57 TF_RETURN_IF_ERROR(RunPredict(run_options, meta_graph_def, {}, session,
58 warmup_record.predict_log().request(),
61 case PredictionLog::kMultiInferenceLog: {
62 MultiInferenceResponse response;
63 TF_RETURN_IF_ERROR(RunMultiInference(
64 run_options, meta_graph_def, {}, session,
65 warmup_record.multi_inference_log().request(), &response));
67 case PredictionLog::kPredictStreamedLog:
68 return errors::Unimplemented(strings::StrCat(
69 "Unsupported log_type for warmup: ", warmup_record.log_type_case()));
70 case PredictionLog::kSessionRunLog:
71 return errors::Unimplemented(strings::StrCat(
72 "Unsupported log_type for warmup: ", warmup_record.log_type_case()));
76 return absl::OkStatus();
81 Status RunSavedModelWarmup(
const ModelWarmupOptions& model_warmup_options,
82 const RunOptions& run_options,
83 const string& export_dir, SavedModelBundle* bundle) {
84 return internal::RunSavedModelWarmup(
85 model_warmup_options, export_dir, [&](PredictionLog prediction_log) {
86 return RunWarmupRequest(prediction_log, run_options,
87 bundle->meta_graph_def, bundle->GetSession());