30 #include "tensorflow/core/kernels/batching_util/periodic_function.h"
31 #include "tensorflow/core/lib/core/notification.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/random/philox_random.h"
35 #include "tensorflow/core/lib/random/simple_philox.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/platform/env.h"
38 #include "tensorflow/core/platform/init_main.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/mutex.h"
41 #include "tensorflow/core/platform/test.h"
42 #include "tensorflow/core/platform/test_benchmark.h"
43 #include "tensorflow/core/platform/types.h"
44 #include "tensorflow_serving/core/aspired_version_policy.h"
45 #include "tensorflow_serving/core/aspired_versions_manager.h"
46 #include "tensorflow_serving/core/availability_preserving_policy.h"
47 #include "tensorflow_serving/core/loader.h"
48 #include "tensorflow_serving/core/manager.h"
49 #include "tensorflow_serving/core/servable_data.h"
50 #include "tensorflow_serving/core/servable_handle.h"
51 #include "tensorflow_serving/core/simple_loader.h"
52 #include "tensorflow_serving/core/test_util/manager_test_util.h"
54 namespace tensorflow {
// Base name of the servables loaded by these benchmarks; BM_GetServableHandle
// appends a stream index to it to form several distinct servable names.
constexpr char kServableName[] = "kServableName";
// Owns everything one read benchmark needs: the AspiredVersionsManager under
// test (manager_), an optional periodic update thread (update_thread_), and a
// Notification used to release reader tasks once they are all scheduled.
71 class BenchmarkState {
// interval_micros: period of the update thread in microseconds; SetUp() only
//   starts the update thread when it is > 0.
// do_work: forwarded to GetLatestVersion() by RunReads().
73 BenchmarkState(
const int interval_micros,
const bool do_work)
74 : interval_micros_(interval_micros), do_work_(do_work) {
75 AspiredVersionsManager::Options options;
// NOTE(review): presumably a negative interval disables the manager's own
// background state-management thread, which is why StartServing() drives the
// manager explicitly through test_util::AspiredVersionsManagerTestAccess —
// confirm against AspiredVersionsManager::Options.
77 options.manage_state_interval_micros = -1;
78 options.aspired_version_policy.reset(
new AvailabilityPreservingPolicy());
79 TF_CHECK_OK(AspiredVersionsManager::Create(std::move(options), &manager_));
// Runs the timed read benchmark with `num_threads` concurrent reader tasks and
// reports the number of items processed to `state`.
83 void RunBenchmark(::testing::benchmark::State& state,
int num_threads);
// Performs `iters` latest-version lookups, checking each result is >= 0.
90 void RunReads(
int iters);
// Aspires `loader_version` as the only version of kServableName and drives
// the manager until exactly one servable id is available.
97 void StartServing(int64_t loader_version);
// Looks up the latest version of kServableName; `do_work` controls extra busy
// work per lookup. Presumably returns the served version number — the return
// statement is not visible here, confirm against the definition.
100 int64_t GetLatestVersion(
bool do_work);
// Reader tasks block on this until RunBenchmark() has resumed timing, so that
// scheduling the reader threads is kept out of the measured region.
105 Notification all_read_threads_scheduled_;
// Created by SetUp() when interval_micros_ > 0 and destroyed by TearDown().
// Its callback captures `this`, so it must be reset before this object dies.
109 std::unique_ptr<PeriodicFunction> update_thread_;
// The manager whose read path is being benchmarked.
113 std::unique_ptr<AspiredVersionsManager> manager_;
// Update-thread period in microseconds; <= 0 means no update thread.
116 const int interval_micros_;
// Aspires `loader_version` as the sole version of kServableName and makes it
// available synchronously.
//
// The loader is a SimpleLoader whose servable is a heap-allocated int64_t set
// to `loader_version`, with no resource estimate. After the single version is
// handed to the aspired-versions callback, the manager is driven by hand: the
// pending aspired-versions request is processed, then the policy is invoked
// three times. Finally it CHECKs that exactly one servable id is available.
123 void BenchmarkState::StartServing(
const int64_t loader_version) {
124 std::unique_ptr<Loader> loader(
new SimpleLoader<int64_t>(
125 [loader_version](std::unique_ptr<int64_t>*
const servable) {
126 servable->reset(
new int64_t);
// The servable's value is the version number, so readers can recover it.
127 **servable = loader_version;
130 SimpleLoader<int64_t>::EstimateNoResources()));
131 std::vector<ServableData<std::unique_ptr<Loader>>> versions;
132 versions.push_back({{kServableName, loader_version}, std::move(loader)});
133 manager_->GetAspiredVersionsCallback()(kServableName, std::move(versions));
// Process the aspired-versions request queued by the callback above.
134 test_util::AspiredVersionsManagerTestAccess(manager_.get())
135 .HandlePendingAspiredVersionsRequests();
// NOTE(review): three policy invocations presumably cover unloading the
// previously served version and loading/serving the new one — confirm the
// exact action sequence of AvailabilityPreservingPolicy.
137 test_util::AspiredVersionsManagerTestAccess(manager_.get())
138 .InvokePolicyAndExecuteAction();
140 test_util::AspiredVersionsManagerTestAccess(manager_.get())
141 .InvokePolicyAndExecuteAction();
143 test_util::AspiredVersionsManagerTestAccess(manager_.get())
144 .InvokePolicyAndExecuteAction();
// Exactly one servable id must be available at this point.
145 CHECK_EQ(1, manager_->ListAvailableServableIds().size());
// Looks up the latest version of kServableName through a Latest request and
// CHECK-fails if the lookup does not succeed. `handle` is a local, so the
// servable handle is released when the function returns.
// NOTE(review): the loop below is synthetic busy work; presumably it runs only
// when `do_work` is true, and the function returns the served int64_t value —
// the guarding condition and the return statement are not visible here.
148 int64_t BenchmarkState::GetLatestVersion(
const bool do_work) {
149 ServableHandle<int64_t> handle;
150 const Status status = manager_->GetServableHandle(
151 ServableRequest::Latest(kServableName), &handle);
152 TF_CHECK_OK(status) << status;
157 for (
int i = 1; i < 10000; ++i) {
166 void BenchmarkState::RunUpdate() { StartServing(GetLatestVersion(
false) + 1); }
// Prepares a benchmark run. When interval_micros_ > 0, starts a
// PeriodicFunction that calls RunUpdate() every interval_micros_ microseconds;
// each call aspires the next servable version.
168 void BenchmarkState::SetUp() {
170 if (interval_micros_ > 0) {
171 PeriodicFunction::Options pf_options;
172 pf_options.thread_name_prefix =
173 "AspiredVersionsManager_Benchmark_Update_Thread";
// The callback captures `this`; TearDown() must reset update_thread_ before
// this BenchmarkState is destroyed.
174 update_thread_.reset(
new PeriodicFunction([
this] { RunUpdate(); },
175 interval_micros_, pf_options));
// Stops periodic updates by destroying update_thread_. This is a no-op if
// SetUp() never created one.
// NOTE(review): presumably PeriodicFunction's destructor waits for an in-flight
// RunUpdate() to finish before returning — confirm.
179 void BenchmarkState::TearDown() {
181 update_thread_.reset();
// Issues `iters` latest-version lookups, passing do_work_ through, and
// CHECK-fails if any returned version is negative.
184 void BenchmarkState::RunReads(
int iters) {
185 for (
int i = 0; i < iters; ++i) {
187 CHECK_GE(GetLatestVersion(do_work_), 0);
// Runs the timed benchmark with `num_threads` reader tasks per iteration and
// reports num_threads * kSubIters items per iteration.
//
// For each benchmark iteration, a fresh ThreadPool with num_threads threads is
// created and num_threads copies of run_reads_fn are scheduled on it. Each
// task first waits on all_read_threads_scheduled_. Timing is resumed after the
// scheduling loop, and the notification then releases the readers.
191 void BenchmarkState::RunBenchmark(::testing::benchmark::State& state,
// Reads performed per reader task (presumably passed to RunReads() inside
// run_reads_fn — that call is not visible here).
203 const int kSubIters = 500;
208 for (
auto s : state) {
// NOTE(review): run_reads_fn captures by reference ([&]). This is only safe if
// `pool` finishes all scheduled work before the locals it captures go out of
// scope — presumably ThreadPool's destructor waits for queued tasks; confirm.
211 std::unique_ptr<thread::ThreadPool> pool(
new thread::ThreadPool(
212 Env::Default(),
"RunBenchmarkReadThread", num_threads));
213 for (
int thread_index = 0; thread_index < num_threads; ++thread_index) {
214 std::function<void()> run_reads_fn = [&]() {
216 all_read_threads_scheduled_.WaitForNotification();
219 pool->Schedule(run_reads_fn);
221 state.ResumeTiming();
// NOTE(review): a Notification may be notified only once, and this member
// outlives the loop. The Notify() below therefore fires on the first benchmark
// iteration only. In later iterations the reader tasks do not block, and they
// may start before ResumeTiming(). Consider a fresh per-iteration notification
// if that skews results.
222 if (!all_read_threads_scheduled_.HasBeenNotified())
223 all_read_threads_scheduled_.Notify();
// Assumes each reader task performs kSubIters reads.
233 state.SetItemsProcessed(num_threads * kSubIters * state.iterations());
// Runs one read benchmark configuration:
// - readers use `num_threads` threads;
// - an update thread runs every `interval_micros` microseconds (SetUp() only
//   starts it when the value is > 0);
// - `do_work` selects extra per-read busy work.
237 void BenchmarkReadsAndUpdates(::testing::benchmark::State& state,
238 int num_threads,
int interval_micros,
240 BenchmarkState bm_state(interval_micros, do_work);
241 bm_state.RunBenchmark(state, num_threads);
// Readers with per-read busy work (do_work = true) and no update thread
// (interval_micros = 0). The reader thread count is taken from state.range(0).
244 void BM_Work_NoUpdates_Reads(::testing::benchmark::State& state) {
245 const int num_threads = state.range(0);
248 BenchmarkReadsAndUpdates(state, num_threads, 0,
true);
// Readers with per-read busy work (do_work = true) while an update thread
// publishes a new version every 1000 microseconds. The reader thread count is
// taken from state.range(0).
251 void BM_Work_FrequentUpdates_Reads(::testing::benchmark::State& state) {
252 const int num_threads = state.range(0);
255 BenchmarkReadsAndUpdates(state, num_threads, 1000,
true);
// Readers without busy work (do_work = false) and no update thread
// (interval_micros = 0). The reader thread count is taken from state.range(0).
258 void BM_NoWork_NoUpdates_Reads(::testing::benchmark::State& state) {
259 const int num_threads = state.range(0);
262 BenchmarkReadsAndUpdates(state, num_threads, 0,
false);
// Readers without busy work (do_work = false) while an update thread publishes
// a new version every 1000 microseconds. The reader thread count is taken from
// state.range(0).
265 void BM_NoWork_FrequentUpdates_Reads(::testing::benchmark::State& state) {
266 const int num_threads = state.range(0);
269 BenchmarkReadsAndUpdates(state, num_threads, 1000,
false);
// Registrations for the four read benchmarks (work/no-work x updates/no-updates).
// NOTE(review): each registration is expected to supply state.range(0), the
// reader thread count; the argument lists are not visible here.
276 BENCHMARK(BM_Work_NoUpdates_Reads)
286 BENCHMARK(BM_Work_FrequentUpdates_Reads)
296 BENCHMARK(BM_NoWork_NoUpdates_Reads)
306 BENCHMARK(BM_NoWork_FrequentUpdates_Reads)
// Benchmarks GetServableHandle() alone, with no update thread, against a mix of
// Latest and Specific requests.
//
// Setup is done once via function-local statics:
// - a manager holding kNumServableStreams servables named kServableName<i>,
//   each with kNumServableVersions versions;
// - a fixed list of kNumRequests randomly generated requests.
// Both are intentionally never freed, so they are reused across invocations of
// this benchmark function.
316 void BM_GetServableHandle(::testing::benchmark::State& state) {
318 constexpr
int kNumServableStreams = 10;
320 constexpr
int kNumServableVersions = 2;
// Initialized on the first call only (function-local static), so loading all
// servables is not repeated on every invocation.
322 static AspiredVersionsManager*
const manager = []() {
323 AspiredVersionsManager::Options options;
// NOTE(review): presumably disables the manager's background state thread so
// the loads below are driven explicitly — confirm against Options.
325 options.manage_state_interval_micros = -1;
326 options.aspired_version_policy.reset(
new AvailabilityPreservingPolicy());
327 std::unique_ptr<AspiredVersionsManager> manager;
328 TF_CHECK_OK(AspiredVersionsManager::Create(std::move(options), &manager));
329 auto aspired_versions_callback = manager->GetAspiredVersionsCallback();
330 for (
int i = 0; i < kNumServableStreams; ++i) {
331 const string servable_name = strings::StrCat(kServableName, i);
332 std::vector<ServableData<std::unique_ptr<Loader>>> versions;
333 for (
int j = 0; j < kNumServableVersions; ++j) {
334 std::unique_ptr<Loader> loader(
new SimpleLoader<int64_t>(
335 [j](std::unique_ptr<int64_t>*
const servable) {
336 servable->reset(
new int64_t);
340 SimpleLoader<int64_t>::EstimateNoResources()));
341 versions.push_back({{servable_name, j}, std::move(loader)});
// Aspire every version of this stream at once, then drive the manager by
// hand: process the aspired-versions request and invoke the policy once per
// version.
344 aspired_versions_callback(servable_name, std::move(versions));
345 test_util::AspiredVersionsManagerTestAccess(manager.get())
346 .HandlePendingAspiredVersionsRequests();
347 for (
int j = 0; j < kNumServableVersions; ++j) {
348 test_util::AspiredVersionsManagerTestAccess(manager.get())
349 .InvokePolicyAndExecuteAction();
// Ownership moves to the static raw pointer; the manager is never deleted.
352 return manager.release();
// Every version of every stream must be available before the timed loop.
354 CHECK_EQ(kNumServableStreams * kNumServableVersions,
355 manager->ListAvailableServableIds().size());
357 constexpr
int kNumRequests = 1024;
// A request is Specific when RandFloat() > kLatestRatio and Latest otherwise.
// That gives about 80% Latest, assuming RandFloat() is uniform in [0, 1).
360 constexpr
float kLatestRatio = 0.8;
// Request list generated once from a Philox RNG seeded with
// testing::RandomSeed(); never freed.
361 static const std::vector<ServableRequest>* requests = []() {
362 std::vector<ServableRequest>* requests(
new std::vector<ServableRequest>());
363 random::PhiloxRandom philox(testing::RandomSeed());
364 random::SimplePhilox random(&philox);
365 for (
int i = 0; i < kNumRequests; ++i) {
367 strings::StrCat(kServableName, random.Uniform(kNumServableStreams));
368 if (random.RandFloat() > kLatestRatio) {
369 const int64_t version = random.Uniform(kNumServableVersions);
370 requests->push_back(ServableRequest::Specific(name, version));
372 requests->push_back(ServableRequest::Latest(name));
// One handle object is reused for every lookup in the timed loop.
378 ServableHandle<int64_t> handle;
// NOTE(review): the request index `i` used below is declared on a line not
// visible here. It presumably advances each iteration so lookups cycle through
// the request list.
380 for (
auto s : state) {
381 const Status status =
382 manager->GetServableHandle(requests->at(i % kNumRequests), &handle);
383 TF_CHECK_OK(status) << status;
386 state.SetItemsProcessed(state.iterations());
// Registers BM_GetServableHandle; it takes no per-registration arguments.
388 BENCHMARK(BM_GetServableHandle);
// Benchmark binary entry point. It runs TensorFlow's process initialization
// (port::InitMain) on the command-line arguments, then runs every registered
// benchmark.
394 int main(
int argc,
char** argv) {
395 tensorflow::port::InitMain(argv[0], &argc, &argv);
396 tensorflow::testing::RunBenchmarks();