#ifndef TENSORFLOW_SERVING_BATCHING_TFRT_SAVED_MODEL_WITH_BATCHING_H_
#define TENSORFLOW_SERVING_BATCHING_TFRT_SAVED_MODEL_WITH_BATCHING_H_

#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow_serving/batching/batching_options.h"
#include "tensorflow_serving/batching/batching_session.h"

namespace tensorflow {
namespace serving {

// Options for batching; aliases the shared batching options from
// batching_options.h.
using SavedModelBatchingOptions = BatchingOptions;

struct SavedModelBatchingTask;

// A function that creates a batch scheduler for one saved-model function. It
// receives the callback that processes a full batch, and outputs the created
// scheduler through the second argument.
using SavedModelBatchingSchedulerCreator = std::function<Status(
    std::function<void(std::unique_ptr<Batch<SavedModelBatchingTask>>)>,
    std::unique_ptr<BatchScheduler<SavedModelBatchingTask>> *)>;
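//
// For instance, a creator backed by BasicBatchScheduler might look like the
// following sketch (the option values are illustrative, not recommendations):
//
//   SavedModelBatchingSchedulerCreator creator =
//       [](std::function<void(std::unique_ptr<Batch<SavedModelBatchingTask>>)>
//              process_batch_callback,
//          std::unique_ptr<BatchScheduler<SavedModelBatchingTask>> *scheduler)
//           -> Status {
//     BasicBatchScheduler<SavedModelBatchingTask>::Options options;
//     options.max_batch_size = 32;
//     options.batch_timeout_micros = 1000;
//     std::unique_ptr<BasicBatchScheduler<SavedModelBatchingTask>> basic;
//     TF_RETURN_IF_ERROR(BasicBatchScheduler<SavedModelBatchingTask>::Create(
//         options, std::move(process_batch_callback), &basic));
//     *scheduler = std::move(basic);
//     return Status();
//   };
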
// Pairs the name of a saved-model function with the scheduler creator to use
// for that function.
struct FuncNameWithBatchingSchedulerCreator {
  absl::string_view func_name;
  SavedModelBatchingSchedulerCreator scheduler_creator;
};

// Wraps `saved_model` so that Run() calls on `saved_model_with_batching` are
// queued and executed in batches, using the scheduler registered for each
// function name.
Status CreateSavedModelWithBatching(
    const SavedModelBatchingOptions &options,
    const std::vector<FuncNameWithBatchingSchedulerCreator>
        &func_name_with_batching_scheduler_creator,
    std::unique_ptr<tfrt::SavedModel> saved_model,
    std::unique_ptr<tfrt::SavedModel> *saved_model_with_batching);
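//
// Example usage, as a sketch: `saved_model` is an already-loaded
// tfrt::SavedModel, `CreateScheduler` stands for any
// SavedModelBatchingSchedulerCreator (such as the one sketched above), and
// the function name "serving_default" is likewise illustrative.
//
//   SavedModelBatchingOptions options;
//   std::unique_ptr<tfrt::SavedModel> batching_model;
//   TF_RETURN_IF_ERROR(CreateSavedModelWithBatching(
//       options, {{"serving_default", CreateScheduler}},
//       std::move(saved_model), &batching_model));
//   // Run() calls on `batching_model` are now batched per function.
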
// The batch task type used by the schedulers above.
struct SavedModelBatchingTask : public BatchingSessionTask {
  // For monitoring purposes.
  static std::string Name() { return "tfrt_saved_model_with_batching"; }

  tfrt::HostContext *host_context;
  absl::Span<const Tensor> tfrt_inputs;
  std::vector<Tensor> *tfrt_outputs;
  tfrt::SavedModel::RunOptions run_options;

  // Set only when this task is a partial task produced by splitting a larger
  // task (see SplitSavedModelInputTask below).
  std::vector<Tensor> tfrt_partial_inputs;
};

// Splits the input task into multiple smaller tasks so that a large request
// can span more than one batch. Suitable as the `split_input_task_func` of a
// batch scheduler that enables large batch splitting.
Status SplitSavedModelInputTask(
    std::unique_ptr<SavedModelBatchingTask> *input_task_ptr,
    int open_batch_remaining_slot, int max_batch_size,
    std::vector<std::unique_ptr<SavedModelBatchingTask>> *output_tasks);
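//
// For example, a scheduler creator can enable large batch splitting by
// configuring its scheduler options along these lines (a sketch; the values
// are illustrative):
//
//   BasicBatchScheduler<SavedModelBatchingTask>::Options options;
//   options.max_batch_size = 64;
//   options.enable_large_batch_splitting = true;
//   options.max_execution_batch_size = 16;
//   options.split_input_task_func = SplitSavedModelInputTask;

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_SERVING_BATCHING_TFRT_SAVED_MODEL_WITH_BATCHING_H_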