TensorFlow Serving C++ API Documentation
tfrt_saved_model_warmup.cc
1 /* Copyright 2020 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_warmup.h"
17 
18 #include <string>
19 
20 #include "google/protobuf/wrappers.pb.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/status/status.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/cc/saved_model/constants.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/lib/io/record_reader.h"
29 #include "tensorflow/core/lib/monitoring/sampler.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/protobuf/config.pb.h"
32 #include "tensorflow_serving/apis/classification.pb.h"
33 #include "tensorflow_serving/apis/inference.pb.h"
34 #include "tensorflow_serving/apis/predict.pb.h"
35 #include "tensorflow_serving/apis/prediction_log.pb.h"
36 #include "tensorflow_serving/apis/regression.pb.h"
37 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
38 #include "tensorflow_serving/servables/tensorflow/tfrt_multi_inference.h"
39 #include "tensorflow_serving/servables/tensorflow/tfrt_predict_util.h"
40 #include "tensorflow_serving/servables/tensorflow/tfrt_regressor.h"
41 #include "tensorflow_serving/servables/tensorflow/util.h"
42 
43 namespace tensorflow {
44 namespace serving {
45 namespace {
46 
47 Status RunWarmupRequest(const PredictionLog& warmup_record,
48  const tfrt::SavedModel::RunOptions& run_options,
49  int lazy_init_threshold,
50  bool skip_warmup_requests_if_initialized,
51  tfrt::SavedModel* saved_model) {
52  // If the signature defs are already initilized and
53  // skip_warmup_requests_if_initialized is set to true, skip executing warmup
54  // requests. We always execute MultiInference warmup requests as it will
55  // trigger the compilation and initialization for combination of signature
56  // defs, which won't be triggered during model loading.
57  if (skip_warmup_requests_if_initialized &&
58  saved_model->GetMetaGraphDef().signature_def_size() <=
59  lazy_init_threshold &&
60  warmup_record.log_type_case() != PredictionLog::kMultiInferenceLog) {
61  return absl::OkStatus();
62  }
63 
64  switch (warmup_record.log_type_case()) {
65  case PredictionLog::kPredictLog: {
66  PredictResponse response;
67  TF_RETURN_IF_ERROR(RunPredict(run_options, {}, saved_model,
68  warmup_record.predict_log().request(),
69  &response));
70  } break;
71  case PredictionLog::kPredictStreamedLog: {
72  if (warmup_record.predict_streamed_log().request_size() == 0) {
73  return absl::InvalidArgumentError(absl::StrCat(
74  "predict_streamed_log does not contain any requests."));
75  }
76  if (warmup_record.predict_streamed_log().request_size() > 1) {
77  return absl::InvalidArgumentError(
78  absl::StrCat("predict_streamed_log contains more than one request, "
79  "which is not supported by PredictStreamed."));
80  }
81  PredictResponse response;
82  auto run_opts = run_options;
83  run_opts.streamed_output_callback =
84  [](absl::flat_hash_map<std::string, tensorflow::Tensor>) {};
85  TF_RETURN_IF_ERROR(RunPredict(
86  run_opts, {}, saved_model,
87  warmup_record.predict_streamed_log().request(0), &response));
88  } break;
89  case PredictionLog::kClassifyLog: {
90  ClassificationResponse response;
91  TF_RETURN_IF_ERROR(RunClassify(run_options, {}, saved_model,
92  warmup_record.classify_log().request(),
93  &response));
94  break;
95  }
96  case PredictionLog::kRegressLog: {
97  RegressionResponse response;
98  TF_RETURN_IF_ERROR(RunRegress(run_options, {}, saved_model,
99  warmup_record.regress_log().request(),
100  &response));
101  break;
102  }
103  case PredictionLog::kMultiInferenceLog: {
104  MultiInferenceResponse response;
105  TF_RETURN_IF_ERROR(RunMultiInference(
106  run_options, {}, saved_model,
107  warmup_record.multi_inference_log().request(), &response));
108  break;
109  }
110  default:
111  return errors::Unimplemented(strings::StrCat(
112  "Unsupported log_type for warmup: ", warmup_record.log_type_case()));
113  break;
114  }
115  return absl::OkStatus();
116 }
117 
118 } // namespace
119 
120 Status RunSavedModelWarmup(const ModelWarmupOptions& model_warmup_options,
121  const string& export_dir, int lazy_init_threshold,
122  bool skip_warmup_requests_if_initialized,
123  tfrt::SavedModel* saved_model) {
124  tfrt::SavedModel::RunOptions run_options; // Default RunOptions.
125  return internal::RunSavedModelWarmup(
126  model_warmup_options, export_dir, [&](PredictionLog prediction_log) {
127  return RunWarmupRequest(
128  prediction_log, run_options, lazy_init_threshold,
129  skip_warmup_requests_if_initialized, saved_model);
130  });
131 }
132 
133 } // namespace serving
134 } // namespace tensorflow