16 #include "tensorflow_serving/servables/tensorflow/classification_service.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/platform/tracing.h"
22 #include "tensorflow_serving/apis/classifier.h"
23 #include "tensorflow_serving/core/servable_handle.h"
24 #include "tensorflow_serving/servables/tensorflow/classifier.h"
25 #include "tensorflow_serving/servables/tensorflow/util.h"
27 namespace tensorflow {
30 Status TensorflowClassificationServiceImpl::Classify(
31 const RunOptions& run_options, ServerCore* core,
32 const thread::ThreadPoolOptions& thread_pool_options,
33 const ClassificationRequest& request, ClassificationResponse* response) {
35 if (!request.has_model_spec()) {
36 return tensorflow::Status(
37 static_cast<absl::StatusCode
>(absl::StatusCode::kInvalidArgument),
41 return ClassifyWithModelSpec(run_options, core, thread_pool_options,
42 request.model_spec(), request, response);
45 Status TensorflowClassificationServiceImpl::ClassifyWithModelSpec(
46 const RunOptions& run_options, ServerCore* core,
47 const thread::ThreadPoolOptions& thread_pool_options,
48 const ModelSpec& model_spec,
const ClassificationRequest& request,
49 ClassificationResponse* response) {
50 TRACELITERAL(
"TensorflowClassificationServiceImpl::ClassifyWithModelSpec");
52 ServableHandle<SavedModelBundle> saved_model_bundle;
53 TF_RETURN_IF_ERROR(core->GetServableHandle(model_spec, &saved_model_bundle));
54 return RunClassify(run_options, saved_model_bundle->meta_graph_def,
55 saved_model_bundle.id().version,
56 saved_model_bundle->session.get(), request, response,