16 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_util.h"
23 #include "google/protobuf/wrappers.pb.h"
24 #include "tensorflow/cc/saved_model/constants.h"
25 #include "tensorflow/core/kernels/batching_util/warmup.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/lib/io/record_reader.h"
29 #include "tensorflow/core/lib/monitoring/sampler.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/status.h"
32 #include "tsl/platform/errors.h"
33 #include "tensorflow_serving/util/threadpool_executor.h"
35 namespace tensorflow {
// Sampler metric (two labels) recording how long model warm-up takes, in
// microseconds of wall time. RunSavedModelWarmup supplies the two label values:
// the export directory and the string form of the final status.
40 auto* model_warm_up_latency = monitoring::Sampler<2>::New(
42 "/tensorflow/serving/model_warmup_latency",
43 "Distribution of wall time (in microseconds) for warming up the model.",
// Exponential histogram buckets built from (10, 1.8, 33); presumably scale,
// growth factor and bucket count -- confirm against
// monitoring::Buckets::Exponential.
47 monitoring::Buckets::Exponential(10, 1.8, 33));
49 uint64_t GetLatencyMicroseconds(
const uint64_t start_microseconds) {
50 const uint64_t end_microseconds = EnvTime::NowMicros();
52 if (end_microseconds < start_microseconds)
return 0;
53 return end_microseconds - start_microseconds;
58 constexpr
char WarmupConsts::kRequestsFileName[];
59 constexpr
int WarmupConsts::kMaxNumRecords;
// Warms up a model by replaying the PredictionLog records stored in the warm-up
// requests file under `export_dir`.
//
// Args:
//   model_warmup_options: supplies the model name/version used to register
//     warm-up state, the all-batch-sizes flag, the per-record iteration count
//     and the number of warm-up threads.
//   export_dir: joined with kSavedModelAssetsExtraDirectory and
//     WarmupConsts::kRequestsFileName to locate the warm-up file; also used as
//     a label value on the warm-up latency metric.
//   warmup_request_executor: callback invoked with each parsed PredictionLog,
//     once per requested iteration.
//
// Returns:
//   OK when no warm-up file exists or when all records were replayed. Otherwise
//   the error from registry registration, opening the file, parsing a record,
//   the executor callback, exceeding WarmupConsts::kMaxNumRecords, or reading
//   the file (DataLoss is re-wrapped with a TFRecord-format hint).
61 Status RunSavedModelWarmup(
62 const ModelWarmupOptions& model_warmup_options,
const string export_dir,
63 std::function<Status(PredictionLog)> warmup_request_executor) {
// Registration handle lives in a function-scope local; presumably it
// unregisters the model when destroyed -- TODO confirm in WarmupStateRegistry.
64 WarmupStateRegistry::Handle warmup_handle;
65 auto per_model_data = std::make_unique<WarmupStateRegistry::PerModelData>();
66 per_model_data->warmup_all_batch_sizes =
67 model_warmup_options.enable_all_batch_sizes_warmup();
// Only register with the global warm-up state registry when a model name was
// provided; a failed registration aborts warm-up.
68 if (!model_warmup_options.model_name().empty()) {
69 auto h = GetGlobalWarmupStateRegistry().Register(
70 {model_warmup_options.model_name(),
71 model_warmup_options.model_version()},
72 std::move(per_model_data));
73 TF_RETURN_IF_ERROR(h.status());
74 warmup_handle = std::move(h.value());
// Start timing before any file-system access so the latency metric covers the
// file lookup, reads and replay.
77 const uint64_t start_microseconds = EnvTime::NowMicros();
78 const string warmup_path =
79 io::JoinPath(export_dir, kSavedModelAssetsExtraDirectory,
80 WarmupConsts::kRequestsFileName);
// A missing warm-up file is not an error: log and return OK. Note that the
// latency metric is not recorded on this path.
81 if (!tensorflow::Env::Default()->FilesExist({warmup_path},
nullptr)) {
82 LOG(INFO) <<
"No warmup data file found at " << warmup_path;
84 return absl::OkStatus();
// How many times each record is sent to the executor; taken from the options
// when the field is set.
86 const int num_request_iterations = [&]() {
87 if (model_warmup_options.has_num_request_iterations()) {
88 return model_warmup_options.num_request_iterations().value();
93 LOG(INFO) <<
"Starting to read warmup data for model at " << warmup_path
94 <<
" with model-warmup-options "
95 << model_warmup_options.DebugString();
96 std::unique_ptr<tensorflow::RandomAccessFile> tf_record_file;
97 TF_RETURN_IF_ERROR(tensorflow::Env::Default()->NewRandomAccessFile(
98 warmup_path, &tf_record_file));
// Worker thread count from the options, clamped to at least 1 when set.
100 int num_model_warmup_threads =
101 model_warmup_options.has_num_model_warmup_threads()
102 ? std::max(model_warmup_options.num_model_warmup_threads().value(), 1)
104 std::unique_ptr<tensorflow::io::SequentialRecordReader> tf_record_file_reader;
106 int num_warmup_records = 0;
// Single-threaded path: read, parse and replay records sequentially on the
// calling thread. Parse failures, executor failures and exceeding the record
// cap return immediately (without recording the latency metric).
107 if (num_model_warmup_threads <= 1) {
108 tf_record_file_reader.reset(
109 new tensorflow::io::SequentialRecordReader(tf_record_file.get()));
111 status = tf_record_file_reader->ReadRecord(&record);
112 tensorflow::serving::PredictionLog prediction_log;
// Loop ends on the first non-OK read status; that status is examined after
// the single- and multi-threaded paths rejoin below.
113 while (status.ok()) {
114 if (!prediction_log.ParseFromArray(record.data(), record.size())) {
115 return errors::InvalidArgument(strings::StrCat(
116 "Failed to parse warmup record: ", record,
" from ", warmup_path));
119 for (
int i = 0; i < num_request_iterations; ++i) {
120 TF_RETURN_IF_ERROR(warmup_request_executor(prediction_log));
// The cap is checked after the record has been replayed and counted.
122 ++num_warmup_records;
123 if (num_warmup_records > WarmupConsts::kMaxNumRecords) {
124 return errors::InvalidArgument(
125 "Number of warmup records exceeds the maximum (",
126 WarmupConsts::kMaxNumRecords,
") at ", warmup_path);
128 status = tf_record_file_reader->ReadRecord(&record);
// Multi-threaded path. The fields below form the state shared by all worker
// tasks; each one is annotated as guarded by `mu`.
132 ::tensorflow::mutex mu;
133 int num_thread_task_done ABSL_GUARDED_BY(mu){0};
134 int num_warmup_records ABSL_GUARDED_BY(mu){0};
// First failure (or the reader's terminal status) observed by any worker.
135 ::tensorflow::Status warm_up_status ABSL_GUARDED_BY(mu);
// Notified by the last worker to finish; the calling thread waits on it.
138 ::tensorflow::condition_variable done ABSL_GUARDED_BY(mu);
139 std::unique_ptr<tensorflow::io::SequentialRecordReader>
140 tf_record_file_reader ABSL_GUARDED_BY(mu);
// Shared ownership: every scheduled task captures `state` by value.
142 const auto state = std::make_shared<SharedState>();
144 std::unique_ptr<Executor> executor;
145 executor.reset(
new ThreadPoolExecutor(Env::Default(),
"Warmup_ThreadPool",
146 num_model_warmup_threads));
// The reader is guarded by `mu`, so it is installed under the lock.
148 ::tensorflow::mutex_lock lock(state->mu);
149 state->tf_record_file_reader.reset(
150 new tensorflow::io::SequentialRecordReader(tf_record_file.get()));
// Schedule one task per warm-up thread. Each task repeatedly pulls a record
// from the shared reader under the lock and then replays it.
152 for (
int i = 0; i < num_model_warmup_threads; ++i) {
153 executor->Schedule([state, num_request_iterations,
154 warmup_request_executor, warmup_path,
155 num_model_warmup_threads]() {
156 Status status = absl::OkStatus();
157 while (status.ok()) {
159 Status execution_status;
160 tensorflow::serving::PredictionLog prediction_log;
// Critical section: check shared status and record cap, then read and parse
// the next record.
162 ::tensorflow::mutex_lock lock(state->mu);
// A non-OK shared status means another worker already hit a failure or the
// reader's terminal status.
163 if (!state->warm_up_status.ok()) {
// Unlike the single-threaded path, the cap is checked before reading the
// next record rather than right after counting one.
166 if (state->num_warmup_records > WarmupConsts::kMaxNumRecords) {
167 state->warm_up_status = errors::InvalidArgument(
168 "Number of warmup records exceeds the maximum (",
169 WarmupConsts::kMaxNumRecords,
") at ", warmup_path);
173 state->tf_record_file_reader->ReadRecord(&record);
// Any read failure is published to the shared status so other workers see it.
174 if (!execution_status.ok()) {
175 state->warm_up_status = execution_status;
178 if (!prediction_log.ParseFromArray(record.data(), record.size())) {
179 state->warm_up_status = errors::InvalidArgument(
180 strings::StrCat(
"Failed to parse warmup record: ", record,
181 " from ", warmup_path));
// Replay the parsed record. NOTE(review): the lock is re-acquired below, so
// the executor appears to run without `mu` held -- confirm scope boundaries.
185 for (
int i = 0; i < num_request_iterations; ++i) {
186 execution_status = warmup_request_executor(prediction_log);
187 if (!execution_status.ok()) {
// Publish an executor failure to the shared status.
191 if (!execution_status.ok()) {
192 ::tensorflow::mutex_lock lock(state->mu);
193 state->warm_up_status = execution_status;
// Count the replayed record and pick up any status set by other workers, which
// controls whether this worker's loop continues.
197 ::tensorflow::mutex_lock lock(state->mu);
198 ++state->num_warmup_records;
199 status = state->warm_up_status;
// Worker is finished: the last one to get here wakes the waiting caller.
201 ::tensorflow::mutex_lock lock(state->mu);
202 if (++state->num_thread_task_done == num_model_warmup_threads) {
203 state->done.notify_one();
// Block the calling thread until every scheduled task has reported done, then
// copy the shared results into the function-level variables.
208 ::tensorflow::mutex_lock lock(state->mu);
209 while (state->num_thread_task_done < num_model_warmup_threads) {
210 state->done.wait(lock);
212 status = state->warm_up_status;
213 num_warmup_records = state->num_warmup_records;
// OutOfRange is rewritten to OK before the metric is recorded; presumably it is
// what the record reader returns at end of file -- TODO confirm.
218 if (errors::IsOutOfRange(status)) {
219 status = absl::OkStatus();
// Record warm-up latency, labelled with the export directory and the final
// status string.
222 const auto warmup_latency = GetLatencyMicroseconds(start_microseconds);
223 model_warm_up_latency->GetCell(export_dir, status.ToString())
224 ->Add(warmup_latency);
// Re-wrap DataLoss with a hint that the file may not be in TFRecord format.
226 if (errors::IsDataLoss(status)) {
227 return errors::DataLoss(
229 ". Please verify your warmup data is in TFRecord format.");
232 TF_RETURN_IF_ERROR(status);
234 LOG(INFO) <<
"Finished reading warmup data for model at " << warmup_path
235 <<
". Number of warmup records read: " << num_warmup_records
236 <<
". Elapsed time (microseconds): " << warmup_latency <<
".";
237 return absl::OkStatus();