16 #include "tensorflow_serving/servables/tensorflow/tfrt_classification_service.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/platform/tracing.h"
22 #include "tensorflow_serving/apis/classifier.h"
23 #include "tensorflow_serving/core/servable_handle.h"
24 #include "tensorflow_serving/model_servers/server_core.h"
25 #include "tensorflow_serving/servables/tensorflow/servable.h"
26 #include "tensorflow_serving/servables/tensorflow/tfrt_classifier.h"
27 #include "tensorflow_serving/servables/tensorflow/util.h"
29 namespace tensorflow {
// Entry point for a ClassificationRequest. It checks that the request names a
// model and then delegates to ClassifyWithModelSpec() with the request's own
// model_spec.
//
// Returns an error status with code kInvalidArgument when `request` carries no
// model_spec. Otherwise it returns whatever ClassifyWithModelSpec() returns.
// `run_options` and `core` are forwarded as-is.
32 Status TFRTClassificationServiceImpl::Classify(
33 const Servable::RunOptions& run_options, ServerCore* core,
34 const ClassificationRequest& request, ClassificationResponse* response) {
// Reject requests that do not identify a model before attempting any lookup.
36 if (!request.has_model_spec()) {
37 return tensorflow::Status(absl::StatusCode::kInvalidArgument,
// The request's embedded model_spec is what selects the servable.
41 return ClassifyWithModelSpec(run_options, core, request.model_spec(), request,
45 Status TFRTClassificationServiceImpl::ClassifyWithModelSpec(
46 const Servable::RunOptions& run_options, ServerCore* core,
47 const ModelSpec& model_spec,
const ClassificationRequest& request,
48 ClassificationResponse* response) {
49 TRACELITERAL(
"TFRTClassificationServiceImpl::ClassifyWithModelSpec");
51 ServableHandle<Servable> servable;
52 TF_RETURN_IF_ERROR(core->GetServableHandle(model_spec, &servable));
53 return servable->Classify(run_options, request, response);