16 #include "tensorflow_serving/servables/tensorflow/predict_impl.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/cc/saved_model/loader.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/platform/threadpool_options.h"
25 #include "tensorflow_serving/core/servable_handle.h"
26 #include "tensorflow_serving/servables/tensorflow/predict_util.h"
27 #include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"
28 #include "tensorflow_serving/servables/tensorflow/util.h"
30 namespace tensorflow {
// Entry point for a PredictRequest.
//
// Checks that the request carries a model_spec, then forwards everything to
// PredictWithModelSpec using request.model_spec() as the target model.
//
// Params:
//   run_options - passed through unchanged to PredictWithModelSpec.
//   request     - the incoming predict request; must have a model_spec.
//   response    - output message; presumably populated downstream by
//                 PredictWithModelSpec (confirm).
// Returns:
//   An InvalidArgument status when request.has_model_spec() is false;
//   otherwise whatever status PredictWithModelSpec returns.
//
// NOTE(review): `core` is used below but is not among the parameters listed
// here; presumably it is a ServerCore* parameter declared alongside them.
// Confirm.
33 Status TensorflowPredictor::Predict(
const RunOptions& run_options,
35 const PredictRequest& request,
36 PredictResponse* response) {
// Without a model_spec there is no way to pick a servable, so reject early.
37 if (!request.has_model_spec()) {
38 return tensorflow::Status(
// NOTE(review): this static_cast converts absl::StatusCode to
// absl::StatusCode and appears to be a no-op. Confirm whether a different
// target code type was intended.
39 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
// Delegate to the model-spec-aware overload with the request's own spec.
42 return PredictWithModelSpec(run_options, core, request.model_spec(), request,
// Runs a prediction against the SavedModelBundle servable identified by
// model_spec.
//
// Flow:
//   1. Look up a ServableHandle<SavedModelBundle> for model_spec via
//      core->GetServableHandle. Any error from the lookup is returned
//      immediately via TF_RETURN_IF_ERROR.
//   2. Call internal::RunPredict with the bundle's meta_graph_def, the
//      servable's version (bundle.id().version), the server's response-tensor
//      serialization option, the bundle's session, and the request/response.
//
// Params:
//   run_options - forwarded to internal::RunPredict.
//   model_spec  - selects which servable to use.
//   request     - forwarded to internal::RunPredict.
//   response    - forwarded to internal::RunPredict; presumably filled in
//                 there (confirm).
// Returns:
//   The lookup error if GetServableHandle fails; otherwise the status
//   returned by internal::RunPredict.
//
// NOTE(review): `core` is used below but is not among the parameters listed
// here; presumably it is a ServerCore* parameter. Confirm.
46 Status TensorflowPredictor::PredictWithModelSpec(
const RunOptions& run_options,
48 const ModelSpec& model_spec,
49 const PredictRequest& request,
50 PredictResponse* response) {
// Presumably the handle keeps the servable loaded while this call uses it.
// Confirm against ServableHandle's documentation.
51 ServableHandle<SavedModelBundle> bundle;
52 TF_RETURN_IF_ERROR(core->GetServableHandle(model_spec, &bundle));
53 return internal::RunPredict(
54 run_options, bundle->meta_graph_def, bundle.id().version,
55 core->predict_response_tensor_serialization_option(),
56 bundle->session.get(), request, response,
// Use default ThreadPoolOptions when no thread_pool_factory_ is configured;
// otherwise use the pools provided by the factory.
57 thread_pool_factory_ ==
nullptr
58 ? thread::ThreadPoolOptions()
59 : thread_pool_factory_->GetThreadPools().get());