TensorFlow Serving C++ API Documentation
tfrt_saved_model_with_batching.h
/* Copyright 2020 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Batching interface on top of TFRT SavedModel. Subject to change since the
// TFRT SavedModel API is temporary and experimental.
#ifndef TENSORFLOW_SERVING_BATCHING_TFRT_SAVED_MODEL_WITH_BATCHING_H_
#define TENSORFLOW_SERVING_BATCHING_TFRT_SAVED_MODEL_WITH_BATCHING_H_

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow_serving/batching/batching_options.h"
#include "tensorflow_serving/batching/batching_session.h"

namespace tensorflow {
namespace serving {

// The batch scheduler task type used for SavedModel batching, for use in batch
// scheduler template parameters, e.g.
// BasicBatchScheduler<SavedModelBatchingTask>.
struct SavedModelBatchingTask;

// A function to construct a batch scheduler for SavedModelBatchingTasks from a
// process-batch callback.
using SavedModelBatchingSchedulerCreator = std::function<Status(
    std::function<void(std::unique_ptr<Batch<SavedModelBatchingTask>>)>,
    std::unique_ptr<BatchScheduler<SavedModelBatchingTask>> *)>;
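
// Example (a minimal sketch): a scheduler creator can be built on
// BasicBatchScheduler; the option values below are illustrative placeholders.
//
//   SavedModelBatchingSchedulerCreator scheduler_creator =
//       [](std::function<void(std::unique_ptr<Batch<SavedModelBatchingTask>>)>
//              process_batch_callback,
//          std::unique_ptr<BatchScheduler<SavedModelBatchingTask>>
//              *batch_scheduler) {
//         BasicBatchScheduler<SavedModelBatchingTask>::Options options;
//         options.max_batch_size = 32;          // Illustrative value.
//         options.batch_timeout_micros = 1000;  // Illustrative value.
//         std::unique_ptr<BasicBatchScheduler<SavedModelBatchingTask>>
//             scheduler;
//         TF_RETURN_IF_ERROR(
//             BasicBatchScheduler<SavedModelBatchingTask>::Create(
//                 options, std::move(process_batch_callback), &scheduler));
//         *batch_scheduler = std::move(scheduler);
//         return Status();
//       };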

// A function name paired with a lambda to create a batch scheduler for Run()
// calls matching the function name.
struct FuncNameWithBatchingSchedulerCreator {
  absl::string_view func_name;
  SavedModelBatchingSchedulerCreator scheduler_creator;
};

using SavedModelBatchingOptions = BatchingOptions;

// Creates `saved_model_with_batching` that batches requests per function
// internally, where the batch scheduler for each function is created according
// to `func_name_with_batching_scheduler_creator`. `saved_model` is the
// underlying core that runs the inference logic and must not be null. Upon
// successful return, `saved_model_with_batching` should be used in the same
// way as a normal SavedModel; Run() calls remain synchronous, and all the
// batching logic is transparent to the caller.
// Also note that the first dimension of every tensor passed to Run() must be
// the batching dimension.
Status CreateSavedModelWithBatching(
    const SavedModelBatchingOptions &options,
    const std::vector<FuncNameWithBatchingSchedulerCreator>
        &func_name_with_batching_scheduler_creator,
    std::unique_ptr<tfrt::SavedModel> saved_model,
    std::unique_ptr<tfrt::SavedModel> *saved_model_with_batching);
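
// Example usage (a minimal sketch; assumes `saved_model` holds a loaded
// tfrt::SavedModel and `scheduler_creator` is defined as above; the function
// name "serving_default" is an illustrative placeholder):
//
//   SavedModelBatchingOptions batching_options;
//   std::vector<FuncNameWithBatchingSchedulerCreator> creators = {
//       {"serving_default", scheduler_creator}};
//   std::unique_ptr<tfrt::SavedModel> saved_model_with_batching;
//   TF_RETURN_IF_ERROR(CreateSavedModelWithBatching(
//       batching_options, creators, std::move(saved_model),
//       &saved_model_with_batching));
//   // Run() on `saved_model_with_batching` now batches transparently.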

struct SavedModelBatchingTask : public BatchingSessionTask {
  // For monitoring purposes.
  static std::string Name() { return "tfrt_saved_model_with_batching"; }

  tfrt::HostContext *host_context;
  absl::Span<const Tensor> tfrt_inputs;
  std::vector<Tensor> *tfrt_outputs;
  tfrt::SavedModel::RunOptions run_options;

  // If the fields below are used, this task is a partial task created by
  // splitting a large batch task.
  std::vector<Tensor> tfrt_partial_inputs;

  // Status shared by all partial tasks split from the same large batch task.
  // The original task succeeds only if all partial tasks succeed.
  ThreadSafeStatus *partial_status = nullptr;
};
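
// Example (a minimal sketch): the process-batch callback handed to a
// scheduler creator (supplied internally by CreateSavedModelWithBatching)
// receives batches of these tasks; `ProcessBatch` is an illustrative name.
//
//   void ProcessBatch(std::unique_ptr<Batch<SavedModelBatchingTask>> batch) {
//     for (int i = 0; i < batch->num_tasks(); ++i) {
//       SavedModelBatchingTask *task = batch->mutable_task(i);
//       // Concatenate `task->tfrt_inputs` along the batching (first)
//       // dimension, run the underlying SavedModel once for the whole
//       // batch, then split the results into `*task->tfrt_outputs`.
//     }
//   }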

// The default implementation of
// `BasicBatchScheduler::Options.split_input_task_func`, for use when the
// corresponding batch scheduler for a batching session sets
// `BasicBatchScheduler::Options.enable_large_batch_splitting` to true.
Status SplitSavedModelInputTask(
    std::unique_ptr<SavedModelBatchingTask> *input_task_ptr,
    int open_batch_remaining_slot, int max_batch_size,
    std::vector<std::unique_ptr<SavedModelBatchingTask>> *output_tasks);
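
// Example (a minimal sketch): enabling large-batch splitting inside a
// scheduler creator; the option values are illustrative placeholders.
//
//   BasicBatchScheduler<SavedModelBatchingTask>::Options options;
//   options.enable_large_batch_splitting = true;
//   options.split_input_task_func = SplitSavedModelInputTask;
//   options.max_execution_batch_size = 16;  // Illustrative value.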

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_SERVING_BATCHING_TFRT_SAVED_MODEL_WITH_BATCHING_H_