#ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFLITE_SESSION_H_
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFLITE_SESSION_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/model.h"
#include "tensorflow_serving/batching/threadsafe_status.h"
#include "tensorflow_serving/servables/tensorflow/serving_session.h"
#include "tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.h"

namespace tensorflow {
namespace serving {

// Maps a tensor name to its TensorInfo and its tensor index in the TFLite
// model.
using TensorInfoMap = std::map<string, std::pair<TensorInfo, int>>;

// Task to be processed by the TFLite batch scheduler. A task corresponds to
// one Run() call and may be split into partial tasks to fill open batches.
struct TfLiteBatchTask : public BatchTask {
  // Creates a full (non-partial) task: completion is signaled via `done`,
  // and the result is written to `status` and `outputs`.
  static void CreateTfLiteBatchTask(
      const std::vector<string>* output_tensor_names,
      std::vector<Tensor>* outputs, Notification* done, Status* status,
      std::unique_ptr<TfLiteBatchTask>* batch_task) {
    TfLiteBatchTask* task = new TfLiteBatchTask();
    task->is_partial = false;
    task->output_tensor_names = output_tensor_names;
    task->outputs = outputs;
    task->done = done;
    task->status = status;
    batch_task->reset(task);
  }
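
  // A minimal usage sketch (illustrative only; `output_names` and
  // `input_tensor` are hypothetical):
  //
  //   std::vector<Tensor> outputs;
  //   Notification done;
  //   Status status;
  //   std::unique_ptr<TfLiteBatchTask> task;
  //   TfLiteBatchTask::CreateTfLiteBatchTask(&output_names, &outputs, &done,
  //                                          &status, &task);
  //   task->inputs.push_back(input_tensor);
  //   // ... enqueue `task` on a batch scheduler, then:
  //   done.WaitForNotification();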
63 static void CreatePartialTfLiteBatchTask(
64 std::vector<int> input_indices,
65 const std::vector<string>* output_tensor_names,
66 std::vector<Tensor>* outputs, std::function<
void()> done_callback,
68 std::unique_ptr<TfLiteBatchTask>* batch_task) {
70 task->is_partial =
true;
71 task->input_indices = input_indices;
72 task->output_tensor_names = output_tensor_names;
73 task->outputs = outputs;
74 task->done_callback = done_callback;
75 task->partial_status = partial_status;
76 batch_task->reset(task);
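
  // Note: partial tasks are produced by TfLiteSession::SplitTfLiteInputTask()
  // below. All splits of one original task share a `partial_status` and a
  // `done_callback`, so the caller observes a single combined result once
  // every split has finished.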

  TfLiteBatchTask() : enqueue_time_micros(EnvTime::NowMicros()) {}

  TfLiteBatchTask(const TfLiteBatchTask&) = delete;
  TfLiteBatchTask& operator=(const TfLiteBatchTask&) = delete;

  // Size of the task, i.e., the batch dimension of the first input tensor.
  size_t size() const override { return inputs[0].dim_size(0); }

  uint64_t start_time_micros() const { return enqueue_time_micros; }

  // Signals completion of a full (non-partial) task.
  Notification* done;

  // Where the result status of a full task is written.
  Status* status;

  // TFLite interpreter indices of the input tensors.
  std::vector<int> input_indices;

  // The input tensors for this task.
  std::vector<Tensor> inputs;

  // Where the output tensors are written.
  std::vector<Tensor>* outputs;

  void set_output(Tensor t) { outputs->push_back(t); }

  // Names of the tensors to fetch.
  const std::vector<string>* output_tensor_names;

  RunOptions run_options;

  const uint64_t enqueue_time_micros;

  // True if this task is a split of a larger task.
  bool is_partial = false;

  // Invoked when a partial task completes.
  std::function<void()> done_callback;

  // Status shared across all splits of the same original task.
  ThreadSafeStatus* partial_status;
};

// A ServingSession that runs inference with a TFLite model, dispatching
// Run() calls to a pool of TFLite interpreters, optionally through a batch
// scheduler.
class TfLiteSession : public ServingSession {
 public:
  // Factory signature for the batch scheduler installed by SetScheduler().
  using SchedulerCreator = std::function<Status(
      const BasicBatchScheduler<TfLiteBatchTask>::Options& options,
      std::function<void(std::unique_ptr<Batch<TfLiteBatchTask>>)>,
      std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>>*)>;
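
  // A minimal sketch of a custom SchedulerCreator (illustrative only); this
  // one just forwards to the default factory declared below:
  //
  //   TfLiteSession::SchedulerCreator creator =
  //       [](const BasicBatchScheduler<TfLiteBatchTask>::Options& options,
  //          std::function<void(std::unique_ptr<Batch<TfLiteBatchTask>>)>
  //              callback,
  //          std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>>* result) {
  //         return TfLiteSession::CreateDefaultBasicBatchScheduler(
  //             options, std::move(callback), result);
  //       };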

  // Creates a TfLiteSession from the serialized TFLite model in `buffer`,
  // which the session takes ownership of. Builds `num_pools` interpreter
  // pools with `num_interpreters_per_pool` interpreters each, and fills in
  // `signatures` from the model.
  static Status Create(string&& buffer, const SessionOptions& options,
                       int num_pools, int num_interpreters_per_pool,
                       std::unique_ptr<TfLiteSession>* tflite_session,
                       ::google::protobuf::Map<string, SignatureDef>* signatures);
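
  // A minimal usage sketch (illustrative only; the model path and pool sizes
  // are hypothetical):
  //
  //   string model_bytes;
  //   TF_RETURN_IF_ERROR(ReadFileToString(
  //       Env::Default(), "/path/to/model.tflite", &model_bytes));
  //   std::unique_ptr<TfLiteSession> session;
  //   ::google::protobuf::Map<string, SignatureDef> signatures;
  //   TF_RETURN_IF_ERROR(TfLiteSession::Create(
  //       std::move(model_bytes), SessionOptions(), /*num_pools=*/1,
  //       /*num_interpreters_per_pool=*/1, &session, &signatures));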

  // Creates a BasicBatchScheduler with `options` that feeds completed
  // batches to `process_batch_callback`.
  static Status CreateDefaultBasicBatchScheduler(
      const BasicBatchScheduler<TfLiteBatchTask>::Options& options,
      std::function<void(std::unique_ptr<Batch<TfLiteBatchTask>>)>
          process_batch_callback,
      std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>>* batch_scheduler);

  // Splits `input_task` into `output_tasks`, so that a task larger than the
  // space left in an open batch (`open_batch_remaining_slot` slots, capacity
  // `max_batch_size`) can still be scheduled as partial tasks.
  static Status SplitTfLiteInputTask(
      std::unique_ptr<TfLiteBatchTask>* input_task_ptr,
      int open_batch_remaining_slot, int max_batch_size,
      std::vector<std::unique_ptr<TfLiteBatchTask>>* output_tasks);
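
  // For example (assuming this mirrors the standard batching split): a task
  // of size 10 entering an open batch with 2 remaining slots and
  // max_batch_size 4 would be split into partial tasks of sizes 2, 4, and 4.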

  ~TfLiteSession() override = default;

  Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs) override;
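
  // A minimal call sketch (illustrative only; tensor names are hypothetical
  // and model-dependent):
  //
  //   Tensor input(DT_FLOAT, TensorShape({1, 224, 224, 3}));
  //   std::vector<Tensor> outputs;
  //   TF_RETURN_IF_ERROR(session->Run({{"input_0", input}}, {"output_0"},
  //                                   /*target_node_names=*/{}, &outputs));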

  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata) override;

  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata,
             const thread::ThreadPoolOptions& thread_pool_options) override;

  Status ListDevices(std::vector<DeviceAttributes>* response) override;

  // Installs a batch scheduler built by `scheduler_creator` with `options`;
  // subsequent Run() calls are batched through it.
  Status SetScheduler(
      const SchedulerCreator& scheduler_creator,
      const BasicBatchScheduler<TfLiteBatchTask>::Options& options);
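
  // A minimal configuration sketch (illustrative only; the option values are
  // hypothetical):
  //
  //   BasicBatchScheduler<TfLiteBatchTask>::Options options;
  //   options.max_batch_size = 32;
  //   options.batch_timeout_micros = 1000;
  //   TF_RETURN_IF_ERROR(session->SetScheduler(
  //       &TfLiteSession::CreateDefaultBasicBatchScheduler, options));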

  BasicBatchScheduler<TfLiteBatchTask>::Options GetSchedulerOptions() {
    return scheduler_options_;
  }

 private:
  TfLiteSession(
      std::map<string, int>&& input_tensor_to_index,
      std::map<string, int>&& output_tensor_to_index, string&& buffer,
      std::unique_ptr<tflite::FlatBufferModel> model,
      std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool);
  Status RunInternal(
      const std::vector<int>& tflite_input_indices,
      const std::vector<std::vector<const Tensor*>>& merged_inputs,
      const std::vector<string>& output_tensor_names,
      std::vector<Tensor>* combined_outputs, int batch_size,
      int* fixed_batch_size = nullptr);
  // Map tensor names to their tensor indices in the TFLite model.
  const std::map<string, int> input_tensor_to_index_;
  const std::map<string, int> output_tensor_to_index_;
  // The serialized model bytes; kept alive for the lifetime of `model_`.
  const string model_serialized_bytes_;
  const std::unique_ptr<tflite::FlatBufferModel> model_;
  const std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool_;
  bool use_fixed_batch_size_;
  std::unique_ptr<BasicBatchScheduler<TfLiteBatchTask>> scheduler_;
  BasicBatchScheduler<TfLiteBatchTask>::Options scheduler_options_;
  // Runs one scheduled batch of tasks on the interpreter pool.
  void ProcessBatch(std::unique_ptr<Batch<TfLiteBatchTask>> batch);
};

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFLITE_SESSION_H_