16 #include "tensorflow_serving/util/json_tensor.h"
25 #include <type_traits>
26 #include <unordered_map>
28 #include "rapidjson/document.h"
29 #include "rapidjson/error/en.h"
30 #include "rapidjson/memorystream.h"
31 #include "rapidjson/prettywriter.h"
32 #include "rapidjson/rapidjson.h"
33 #include "rapidjson/reader.h"
34 #include "rapidjson/stringbuffer.h"
35 #include "absl/strings/escaping.h"
36 #include "absl/strings/match.h"
37 #include "absl/strings/numbers.h"
38 #include "absl/strings/str_format.h"
39 #include "absl/strings/str_join.h"
40 #include "tensorflow/core/example/example.pb.h"
41 #include "tensorflow/core/example/feature.pb.h"
42 #include "tensorflow/core/framework/tensor.h"
43 #include "tensorflow/core/framework/tensor_shape.pb.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/framework/types.pb.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/stringpiece.h"
48 #include "tensorflow_serving/apis/input.pb.h"
49 #include "tensorflow_serving/apis/model.pb.h"
51 namespace tensorflow {
// JSON key that optionally names the model signature in a request body; read
// by FillSignature() and copied into model_spec.signature_name.
56 constexpr
char kPredictRequestSignatureKey[] =
"signature_name";
// JSON key of the row-format predict request: a list of per-instance values
// or objects (handled by FillTensorMapFromInstancesList()).
60 constexpr
char kPredictRequestInstancesKey[] =
"instances";
// JSON key of the columnar-format predict request: an object keyed by input
// name, or a plain value/list for single-input signatures (handled by
// FillTensorMapFromInputsMap()).
64 constexpr
char kPredictRequestInputsKey[] =
"inputs";
// JSON key of the optional context Example in classify/regress requests.
67 constexpr
char kClassifyRegressRequestContextKey[] =
"context";
// JSON key of the required, non-empty list of Examples in classify/regress
// requests.
70 constexpr
char kClassifyRegressRequestExamplesKey[] =
"examples";
// Top-level key of a row-format predict response.
74 constexpr
char kPredictResponsePredictionsKey[] =
"predictions";
// Top-level key of a columnar-format predict response.
78 constexpr
char kPredictResponseOutputsKey[] =
"outputs";
// Top-level key of classify and regress responses.
82 constexpr
char kClassifyRegressResponseKey[] =
"results";
// Top-level key of the JSON object written by MakeJsonFromStatus().
85 constexpr
char kErrorResponseKey[] =
"error";
// Sole member name of a {"b64": "<base64 text>"} wrapper object used to carry
// binary string data in both directions.
88 constexpr
char kBase64Key[] =
"b64";
// DT_STRING tensors whose name ends with this suffix are written back to JSON
// as {"b64": ...} objects (see IsNamedTensorBytes()).
91 constexpr
char kBytesTensorNameSuffix[] =
"_bytes";
// Writer type for JSON output: pretty-prints into an in-memory
// rapidjson::StringBuffer.
93 using RapidJsonWriter = rapidjson::PrettyWriter<rapidjson::StringBuffer>;
// Returns a short name for the JSON type of `val`. Used when building error
// messages (TypeError, AddValueToFeature, MakeExampleFromJsonObject).
// NOTE(review): the return statement under each case label is not visible in
// this excerpt, so the exact strings cannot be confirmed here.
95 string JsonTypeString(
const rapidjson::Value& val) {
96 switch (val.GetType()) {
97 case rapidjson::kNullType:
99 case rapidjson::kFalseType:
101 case rapidjson::kTrueType:
103 case rapidjson::kObjectType:
105 case rapidjson::kArrayType:
107 case rapidjson::kStringType:
109 case rapidjson::kNumberType:
// Parses the decimal text `s` into `*out` and returns whether parsing
// succeeded. This generic template forwards to absl::SimpleAtof, so it is
// meant for float. A separate non-template overload handles double.
114 template <
typename dtype>
115 bool StringToDecimal(
const absl::string_view s, dtype* out) {
116 return absl::SimpleAtof(s, out);
// Non-template overload chosen for double. It parses with absl::SimpleAtod so
// the value is read at double (not float) precision. Returns whether parsing
// succeeded.
120 bool StringToDecimal(
const absl::string_view s,
double* out) {
121 return absl::SimpleAtod(s, out);
// Writes the float/double `val` to `writer` as a raw JSON number token and
// returns the result of RapidJsonWriter::RawValue().
// - Finite values: first formatted with absl::StrCat.
// - The StrCat text is re-parsed with StringToDecimal. On a path whose
//   governing condition is only partly visible here, it is then replaced by a
//   "%.*g" rendering with max_digits10 significant digits. This is presumably
//   so the text round-trips to the same value — confirm.
// - ".0" is appended when the text has neither '.' nor 'e'. This is presumably
//   so consumers see a floating-point literal.
// - Infinities become "Infinity" / "-Infinity". These are non-standard JSON
//   literals; ParseJson() accepts them via kParseNanAndInfFlag.
// - The NaN branch body is not visible in this excerpt.
124 template <
typename dtype>
125 bool WriteDecimal(RapidJsonWriter* writer, dtype val) {
127 std::is_same<dtype, float>::value || std::is_same<dtype, double>::value,
128 "Only floating-point value types are supported.");
138 if (std::isfinite(val)) {
139 decimal_str = absl::StrCat(val);
143 if (!StringToDecimal(decimal_str, &num)) {
147 decimal_str = absl::StrFormat(
148 "%.*g", std::numeric_limits<dtype>::max_digits10, val);
// Integral-looking output (e.g. "3") would otherwise read back as an integer.
159 if (!absl::StrContains(decimal_str,
'.') &&
160 !absl::StrContains(decimal_str,
'e')) {
161 absl::StrAppend(&decimal_str,
".0");
163 }
else if (std::isnan(val)) {
165 }
else if (std::isinf(val)) {
166 decimal_str = std::signbit(val) ?
"-Infinity" :
"Infinity";
// Emit the prepared text verbatim as a number token (no quoting).
168 return writer->RawValue(decimal_str.c_str(), decimal_str.size(),
169 rapidjson::kNumberType);
// rapidjson SAX handler that forwards every event to `writer` while the number
// of bytes already in `buffer` is below `max_bytes`. Once the limit is reached,
// every handler returns false. In rapidjson, a handler that returns false stops
// the event producer (e.g. Value::Accept), which truncates the output.
// The limit is checked before each event, not after. The final output can
// therefore exceed `max_bytes` by the size of the last event written.
177 class JsonWriterWithLimit {
179 JsonWriterWithLimit(rapidjson::StringBuffer* buffer, RapidJsonWriter* writer,
181 : buffer_(buffer), writer_(writer), max_bytes_(max_bytes) {};
// Each handler below is a pass-through guarded by InLimit().
183 bool Null() {
return InLimit() ? writer_->Null() :
false; }
184 bool Bool(
bool b) {
return InLimit() ? writer_->Bool(b) :
false; }
185 bool Int(
int i) {
return InLimit() ? writer_->Int(i) :
false; }
186 bool Uint(
unsigned u) {
return InLimit() ? writer_->Uint(u) :
false; }
187 bool Int64(int64_t i64) {
return InLimit() ? writer_->Int64(i64) :
false; }
188 bool Uint64(uint64_t u64) {
return InLimit() ? writer_->Uint64(u64) :
false; }
189 bool Double(
double d) {
return InLimit() ? writer_->Double(d) :
false; }
191 bool RawNumber(
const rapidjson::MemoryStream::Ch* str,
192 rapidjson::SizeType length,
bool copy =
false) {
193 return InLimit() ? writer_->RawNumber(str, length, copy) :
false;
196 bool String(
const rapidjson::MemoryStream::Ch* str,
197 rapidjson::SizeType length,
bool copy =
false) {
198 return InLimit() ? writer_->String(str, length, copy) :
false;
201 bool StartObject() {
return InLimit() ? writer_->StartObject() :
false; }
203 bool Key(
const rapidjson::MemoryStream::Ch* str, rapidjson::SizeType length,
205 return InLimit() ? writer_->Key(str, length, copy) :
false;
208 bool EndObject(rapidjson::SizeType memberCount = 0) {
209 return InLimit() ? writer_->EndObject(memberCount) :
false;
212 bool StartArray() {
return InLimit() ? writer_->StartArray() :
false; }
214 bool EndArray(rapidjson::SizeType memberCount = 0) {
215 return InLimit() ? writer_->EndArray(memberCount) :
false;
// True while fewer than max_bytes_ bytes have been written to buffer_.
// NOTE(review): this compares size_t GetSize() against an int. A negative
// max_bytes_ would convert to a huge unsigned value and disable the limit.
219 bool InLimit() {
return buffer_->GetSize() < max_bytes_; }
220 const rapidjson::StringBuffer*
const buffer_;
221 RapidJsonWriter*
const writer_;
222 const int max_bytes_;
// Soft cap, in bytes, on the JSON text that JsonValueToDebugString() embeds in
// error messages. It is soft because JsonWriterWithLimit checks the limit
// before each write.
225 constexpr
int kMaxJsonDebugStringBytes = 256;
// Renders `val` as JSON for error messages, with arrays on a single line.
// Output is bounded by kMaxJsonDebugStringBytes through JsonWriterWithLimit,
// so large inputs are truncated.
// NOTE(review): the statement that feeds `val` into `j` (presumably
// val.Accept(j)) is not visible in this excerpt.
229 string JsonValueToDebugString(
const rapidjson::Value& val) {
230 rapidjson::StringBuffer buffer;
231 RapidJsonWriter writer(buffer);
232 writer.SetFormatOptions(rapidjson::kFormatSingleLineArray);
233 JsonWriterWithLimit j(&buffer, &writer, kMaxJsonDebugStringBytes);
235 return buffer.GetString();
// Formats `shape` for error messages.
// - Unknown-rank shapes take a separate branch of the conditional.
// - Known-rank shapes are rendered from each dimension's size by the formatter
//   lambda below, presumably as an absl::StrJoin formatter.
// The unknown-rank text and the join separator are on lines not visible here.
239 string ShapeToString(
const TensorShapeProto& shape) {
240 return shape.unknown_rank()
246 [](
string* out,
const TensorShapeProto_Dim& dim) {
247 out->append(absl::StrCat(dim.size()));
// Returns true iff both shapes have known rank, the same number of dimensions,
// and the same size in every dimension. Dimension names are ignored.
// An unknown-rank shape never compares equal, not even to another
// unknown-rank shape.
252 bool IsShapeEqual(
const TensorShapeProto& lhs,
const TensorShapeProto& rhs) {
253 return !lhs.unknown_rank() && !rhs.unknown_rank() &&
254 lhs.dim_size() == rhs.dim_size() &&
255 std::equal(lhs.dim().begin(), lhs.dim().end(), rhs.dim().begin(),
256 [](
const TensorShapeProto_Dim& lhs,
257 const TensorShapeProto_Dim& rhs) {
258 return lhs.size() == rhs.size();
// Builds an InvalidArgument status saying that the JSON value `val` (shown
// truncated, with its JSON type) cannot be converted to tensor type `dtype`.
262 Status TypeError(
const rapidjson::Value& val, DataType dtype) {
263 return errors::InvalidArgument(
264 "JSON Value: ", JsonValueToDebugString(val),
265 " Type: ", JsonTypeString(val),
266 " is not of expected type: ", DataTypeString(dtype));
// Builds an InvalidArgument status for a value that was expected to be a
// {"b64": "<string>"} wrapper object but is not.
269 Status Base64FormatError(
const rapidjson::Value& val) {
270 return errors::InvalidArgument(
"JSON Value: ", JsonValueToDebugString(val),
271 " not formatted correctly for base64 data");
// Builds an InvalidArgument status that starts with the (truncated) JSON of
// `val`, followed by a space and the concatenation of `args`.
274 template <
typename... Args>
275 Status FormatError(
const rapidjson::Value& val, Args&&... args) {
276 return errors::InvalidArgument(
"JSON Value: ", JsonValueToDebugString(val),
277 " ", std::forward<Args>(args)...);
// Builds the InvalidArgument status returned when "signature_name" is present
// in the request but is not a JSON string.
280 Status FormatSignatureError(
const rapidjson::Value& val) {
281 return errors::InvalidArgument(
282 "JSON Value: ", JsonValueToDebugString(val),
283 " not formatted correctly. 'signature_name' key must be a string value.");
// Appends the scalar JSON value `val` to the typed repeated field of `tensor`
// that matches `dtype`. It first checks that the JSON type is compatible and
// returns TypeError otherwise.
// Visible branches:
// - float_val and double_val accept any JSON number (float is narrowed via
//   GetFloat).
// - int_val requires IsInt.
// - string_val requires a string.
// - int64_val requires IsInt64.
// - bool_val requires a boolean.
// - uint32_val requires IsUint.
// - uint64_val requires IsUint64.
// The trailing Unimplemented return presumably covers every other dtype.
// NOTE(review): the case labels that map each DataType to a branch are not
// visible in this excerpt. Which dtypes share int_val cannot be confirmed here.
290 Status AddValueToTensor(
const rapidjson::Value& val, DataType dtype,
291 TensorProto* tensor) {
// Float branch: integers are accepted too and converted to float.
294 if (!val.IsNumber())
return TypeError(val, dtype);
295 tensor->add_float_val(val.GetFloat());
299 if (!val.IsNumber())
return TypeError(val, dtype);
300 tensor->add_double_val(val.GetDouble());
// int_val branch: the number must fit in a signed 32-bit int.
307 if (!val.IsInt())
return TypeError(val, dtype);
308 tensor->add_int_val(val.GetInt());
// Plain JSON strings are stored as-is; base64 objects are handled by the
// caller (FillTensorProto).
312 if (!val.IsString())
return TypeError(val, dtype);
313 tensor->add_string_val(val.GetString(), val.GetStringLength());
317 if (!val.IsInt64())
return TypeError(val, dtype);
318 tensor->add_int64_val(val.GetInt64());
322 if (!val.IsBool())
return TypeError(val, dtype);
323 tensor->add_bool_val(val.GetBool());
327 if (!val.IsUint())
return TypeError(val, dtype);
328 tensor->add_uint32_val(val.GetUint());
332 if (!val.IsUint64())
return TypeError(val, dtype);
333 tensor->add_uint64_val(val.GetUint64());
337 return errors::Unimplemented(
338 "Conversion of JSON Value: ", JsonValueToDebugString(val),
339 " to type: ", DataTypeString(dtype));
// Infers a dense shape from nested JSON arrays and appends it to `shape`.
// - Each array level adds one dimension equal to that array's length.
// - Recursion follows only the first element at each level.
// - A non-array value ends the recursion, so a scalar adds no dimensions.
// Sibling arrays are not inspected. FillTensorProto() later rejects ragged
// input by comparing each list's length with these dimensions.
// NOTE(review): the guard on the line not visible before the recursive call
// presumably skips it for empty arrays, where val[0] would be invalid.
349 void GetDenseTensorShape(
const rapidjson::Value& val, TensorShapeProto* shape) {
350 if (!val.IsArray())
return;
351 const auto size = val.Size();
352 shape->add_dim()->set_size(size);
354 GetDenseTensorShape(val[0], shape);
// Checks whether `val` is a base64 wrapper: an object with exactly one member,
// "b64", whose value is a string.
// The return statements are not visible in this excerpt. Presumably it
// returns true in that case and false otherwise.
358 bool IsValBase64Object(
const rapidjson::Value& val) {
364 if (val.IsObject()) {
365 const auto itr = val.FindMember(kBase64Key);
366 if (itr != val.MemberEnd() && val.MemberCount() == 1 &&
367 itr->value.IsString()) {
// Base64-decodes the "b64" string of wrapper object `val`.
// The output argument passed to absl::Base64Unescape is on a line not visible
// here, presumably `decoded_val`.
// Returns:
// - Base64FormatError if `val` is not a wrapper (see IsValBase64Object);
// - InvalidArgument if absl::Base64Unescape rejects the text.
377 Status JsonDecodeBase64Object(
const rapidjson::Value& val,
378 string* decoded_val) {
379 if (!IsValBase64Object(val)) {
380 return Base64FormatError(val);
// Safe to dereference: IsValBase64Object guarantees a string "b64" member.
382 const auto itr = val.FindMember(kBase64Key);
383 if (!absl::Base64Unescape(absl::string_view(itr->value.GetString(),
384 itr->value.GetStringLength()),
386 return errors::InvalidArgument(
"Unable to base64 decode");
// Recursively copies the values of the (possibly nested) JSON value `val`
// into `tensor`.
// - `tensor->tensor_shape()` must already hold the expected dense shape (see
//   GetDenseTensorShape()).
// - `level` is the current nesting depth, 0 at the top.
// - `dtype` selects how leaf values are converted.
// - `*val_count` is incremented once per leaf value added successfully.
// Returns InvalidArgument when:
// - a leaf sits at the wrong depth;
// - a list is nested deeper than the shape's rank;
// - a list's length differs from the shape's dimension at that level;
// - a leaf cannot be converted to `dtype`.
// Some of the governing conditions are on lines not visible in this excerpt.
392 Status FillTensorProto(
const rapidjson::Value& val,
int level, DataType dtype,
393 int* val_count, TensorProto* tensor) {
394 const auto rank = tensor->tensor_shape().dim_size();
// Leaf (non-array) value. The depth check preceding the error below is not
// visible here; presumably it requires the leaf to sit at depth == rank.
395 if (!val.IsArray()) {
399 return errors::InvalidArgument(
400 "JSON Value: ", JsonValueToDebugString(val),
401 " found at incorrect level: ", level + 1,
402 " in the JSON DOM. Expected at level: ", rank);
// JSON objects are accepted only as {"b64": ...} payloads for DT_STRING. Any
// other object is a type error.
405 if (val.IsObject()) {
406 status = (dtype == DT_STRING)
407 ? JsonDecodeBase64Object(val, tensor->add_string_val())
408 : TypeError(val, dtype);
410 status = AddValueToTensor(val, dtype, tensor);
// Only successfully converted leaves are counted.
412 if (status.ok()) (*val_count)++;
// Array value nested deeper than the inferred rank (guard not visible here).
418 return errors::InvalidArgument(
419 "Encountered list at unexpected level: ", level,
" expected < ", rank);
// Reject ragged input: every list at this level must match the dimension
// inferred from the first element.
423 if (val.Size() != tensor->tensor_shape().dim(level).size()) {
424 return errors::InvalidArgument(
425 "Encountered list at unexpected size: ", val.Size(),
426 " at level: ", level,
427 " expected size: ", tensor->tensor_shape().dim(level).size());
// Recurse into each element one level deeper; stop at the first error.
431 for (
const auto& v : val.GetArray()) {
432 TF_RETURN_IF_ERROR(FillTensorProto(v, level + 1, dtype, val_count, tensor));
// Adds one value/list `item` from a row-format ("instances") request to the
// tensor for input `name` in `tensor_map`. It also checks that the item agrees
// with the items seen earlier for that input.
// - `size_map` / `shape_map` record, per input name, the value count and shape
//   of the first item seen. Later items must match both exactly.
// - Values accumulate in the same TensorProto across items. Only the shape is
//   reset and re-inferred for each item.
// Returns InvalidArgument if:
// - `name` is not a signature input;
// - the item cannot be converted;
// - its size or shape differs from earlier items.
443 Status AddInstanceItem(
const rapidjson::Value& item,
const string& name,
444 const ::google::protobuf::Map<string, TensorInfo>& tensorinfo_map,
445 ::google::protobuf::Map<string, int>* size_map,
446 ::google::protobuf::Map<string, TensorShapeProto>* shape_map,
447 ::google::protobuf::Map<string, TensorProto>* tensor_map) {
448 if (!tensorinfo_map.count(name)) {
449 return errors::InvalidArgument(
"JSON object: does not have named input: ",
453 const auto dtype = tensorinfo_map.at(name).dtype();
454 auto* tensor = &(*tensor_map)[name];
// Recompute this item's shape from scratch. Previously added values are kept.
455 tensor->mutable_tensor_shape()->Clear();
456 GetDenseTensorShape(item, tensor->mutable_tensor_shape());
458 FillTensorProto(item, 0 , dtype, &size, tensor));
// The first item for `name` defines the expected size and shape.
459 if (!size_map->count(name)) {
460 (*size_map)[name] = size;
461 (*shape_map)[name] = tensor->tensor_shape();
462 }
else if ((*size_map)[name] != size) {
463 return errors::InvalidArgument(
"Expecting tensor size: ", (*size_map)[name],
465 }
else if (!IsShapeEqual((*shape_map)[name], tensor->tensor_shape())) {
466 return errors::InvalidArgument(
467 "Expecting shape ", ShapeToString((*shape_map)[name]),
468 " but got: ", ShapeToString(tensor->tensor_shape()));
// Parses `json` into `*doc` and requires the root to be a JSON object.
// Returns InvalidArgument for:
// - an empty document (the guard is on a line not visible here);
// - a rapidjson parse error, reported with its message and byte offset;
// - a root that is not an object.
473 Status ParseJson(
const absl::string_view json, rapidjson::Document* doc) {
475 return errors::InvalidArgument(
"JSON Parse error: The document is empty");
// Read directly from the string_view's bytes, decoded as UTF-8.
480 rapidjson::MemoryStream ms(json.data(), json.size());
481 rapidjson::EncodedInputStream<rapidjson::UTF8<>, rapidjson::MemoryStream>
// kParseIterativeFlag: iterative parsing with a constant call-stack size, so
//   deeply nested input cannot overflow the stack.
// kParseNanAndInfFlag: accept NaN/Infinity literals. WriteDecimal emits
//   "Infinity"/"-Infinity" on output.
// kParseStopWhenDoneFlag: stop after the first complete root value instead of
//   reporting trailing content. Any trailing-data check would be on the lines
//   not visible here.
483 constexpr
auto parse_flags = rapidjson::kParseIterativeFlag |
484 rapidjson::kParseNanAndInfFlag |
485 rapidjson::kParseStopWhenDoneFlag;
486 if (doc->ParseStream<parse_flags>(jsonstream).HasParseError()) {
487 return errors::InvalidArgument(
488 "JSON Parse error: ", rapidjson::GetParseError_En(doc->GetParseError()),
489 " at offset: ", doc->GetErrorOffset());
498 if (!doc->IsObject()) {
499 return FormatError(*doc,
"Is not object");
// Copies the optional top-level "signature_name" string of `doc` into
// request->model_spec().signature_name.
// - If the key is absent, `request` is left unchanged.
// - If the key is present but not a string, returns FormatSignatureError.
// Templated so both Predict and Classify/Regress requests can use it.
504 template <
typename RequestTypeProto>
505 Status FillSignature(
const rapidjson::Document& doc,
506 RequestTypeProto* request) {
508 auto itr = doc.FindMember(kPredictRequestSignatureKey);
509 if (itr != doc.MemberEnd()) {
510 if (!itr->value.IsString()) {
511 return FormatSignatureError(doc);
513 request->mutable_model_spec()->set_signature_name(
514 itr->value.GetString(), itr->value.GetStringLength());
// Converts the row-format "instances" array (`itr->value`) into one
// TensorProto per signature input in `tensor_map`.
// The array is indexed at [0] without a visible emptiness check here.
// FillPredictRequestFromJson checks that it is a non-empty array first.
// The first element decides between two layouts:
// - A list of objects {input_name: value, ...}: every element must be such an
//   object, and its key set must equal the signature's input names.
// - A list of plain values/lists (or {"b64": ...} objects): allowed only for a
//   single-input signature, which receives every element.
// Each tensor's final shape is [tensor_count] + the per-instance shape.
// tensor_count presumably counts instances and is incremented on lines not
// visible here.
519 Status FillTensorMapFromInstancesList(
520 const rapidjson::Value::MemberIterator& itr,
521 const ::google::protobuf::Map<string, tensorflow::TensorInfo>& tensorinfo_map,
522 ::google::protobuf::Map<string, TensorProto>* tensor_map) {
// NOTE(review): IsObject() is also true for a {"b64": ...} first element. A
// multi-input signature whose first instance is a base64 object therefore
// skips this error. It then takes the single-input path below
// (elements_are_objects is false) and fills only the first input. Testing
// IsElementObject(itr->value[0]) here would match the layout decision below.
526 if (!itr->value[0].IsObject() && tensorinfo_map.size() > 1) {
527 return errors::InvalidArgument(
528 "instances is a plain list, but expecting list of objects as multiple "
529 "input tensors required as per tensorinfo_map");
// A "named inputs" element is an object that is not a base64 wrapper.
532 auto IsElementObject = [](
const rapidjson::Value& val) {
533 return val.IsObject() && !IsValBase64Object(val);
536 const bool elements_are_objects = IsElementObject(itr->value[0]);
// Expected key set for each instance object.
538 std::set<string> input_names;
539 for (
const auto& kv : tensorinfo_map) input_names.insert(kv.first);
// Per-input size/shape of the first instance; later instances must match.
547 ::google::protobuf::Map<string, int> size_map;
548 ::google::protobuf::Map<string, TensorShapeProto> shape_map;
549 int tensor_count = 0;
550 for (
const auto& elem : itr->value.GetArray()) {
551 if (elements_are_objects) {
// NOTE(review): this message says "got list" even when the element is a
// scalar or a base64 object.
552 if (!IsElementObject(elem)) {
553 return errors::InvalidArgument(
"Expecting object but got list at item ",
554 tensor_count,
" of input list");
556 std::set<string> object_keys;
557 for (
const auto& kv : elem.GetObject()) {
558 const string& name = kv.name.GetString();
559 object_keys.insert(name);
560 const auto status = AddInstanceItem(kv.value, name, tensorinfo_map,
561 &size_map, &shape_map, tensor_map);
563 return errors::InvalidArgument(
564 "Failed to process element: ", tensor_count,
" key: ", name,
565 " of 'instances' list. Error: ", status.ToString());
// Every instance must supply exactly the signature's inputs: none missing,
// none extra.
568 if (input_names != object_keys) {
569 return errors::InvalidArgument(
570 "Failed to process element: ", tensor_count,
571 " of 'instances' list. JSON object: ", JsonValueToDebugString(elem),
572 " keys must be equal to: ", absl::StrJoin(input_names,
","));
// Plain-list layout: mixing in named-input objects is an error.
575 if (IsElementObject(elem)) {
576 return errors::InvalidArgument(
577 "Expecting value/list but got object at item ", tensor_count,
// Single-input signature: every element feeds that input.
580 const auto& name = tensorinfo_map.begin()->first;
581 const auto status = AddInstanceItem(elem, name, tensorinfo_map, &size_map,
582 &shape_map, tensor_map);
584 return errors::InvalidArgument(
585 "Failed to process element: ", tensor_count,
586 " of 'instances' list. Error: ", status.ToString());
// Finalize each tensor: take dtype from the signature and prepend the
// instance-count dimension to the per-instance shape.
594 for (
auto& kv : *tensor_map) {
595 const string& name = kv.first;
596 auto* tensor = &kv.second;
597 tensor->set_dtype(tensorinfo_map.at(name).dtype());
598 const auto& shape = shape_map.at(name);
599 auto* output_shape = tensor->mutable_tensor_shape();
600 output_shape->Clear();
601 output_shape->add_dim()->set_size(tensor_count);
602 for (
const auto& d : shape.dim())
603 output_shape->add_dim()->set_size(d.size());
// Converts the columnar-format "inputs" value (`itr->value`) into TensorProtos
// in `tensor_map`. Two forms are accepted:
// - A plain value/list or {"b64": ...} object: only for single-input
//   signatures, and it fills that input.
// - Otherwise, `inputs` must be an object with a key for every signature
//   input. Keys that are not signature inputs are never examined here.
// The JSON nesting becomes the tensor shape directly; no batch dimension is
// added.
609 Status FillTensorMapFromInputsMap(
610 const rapidjson::Value::MemberIterator& itr,
611 const ::google::protobuf::Map<string, tensorflow::TensorInfo>& tensorinfo_map,
612 ::google::protobuf::Map<string, TensorProto>* tensor_map) {
616 const rapidjson::Value& val = itr->value;
617 if (!val.IsObject() || IsValBase64Object(val)) {
618 if (tensorinfo_map.size() > 1) {
619 return errors::InvalidArgument(
620 "inputs is a plain value/list, but expecting an object as multiple "
621 "input tensors required as per tensorinfo_map");
// NOTE(review): unlike the named-input path below, the shape is not
// Clear()ed before inference. Any dims already on this entry would be
// appended to.
624 auto* tensor = &(*tensor_map)[tensorinfo_map.begin()->first];
625 tensor->set_dtype(tensorinfo_map.begin()->second.dtype());
626 GetDenseTensorShape(val, tensor->mutable_tensor_shape());
628 TF_RETURN_IF_ERROR(FillTensorProto(val, 0 , tensor->dtype(),
629 &unused_size, tensor));
// Named inputs: look up each signature input in the "inputs" object.
631 for (
const auto& kv : tensorinfo_map) {
632 const auto& name = kv.first;
633 auto item = val.FindMember(name.c_str());
634 if (item == val.MemberEnd()) {
635 return errors::InvalidArgument(
"Missing named input: ", name,
636 " in 'inputs' object.");
638 const auto dtype = kv.second.dtype();
639 auto* tensor = &(*tensor_map)[name];
640 tensor->set_dtype(dtype);
641 tensor->mutable_tensor_shape()->Clear();
642 GetDenseTensorShape(item->value, tensor->mutable_tensor_shape());
644 TF_RETURN_IF_ERROR(FillTensorProto(item->value, 0 , dtype,
645 &unused_size, tensor));
// Parses a Predict REST request body `json` into `request`.
// - The callback parameter's name is on a line not visible here; it is used
//   below as `get_tensorinfo_map`. It is called with the signature name to
//   obtain that signature's input TensorInfo map, which must be non-empty.
// - Exactly one of "instances" (row format) or "inputs" (columnar format)
//   must be present.
// - `*format` starts as kInvalid and is set to kRow/kColumnar before the
//   tensors are filled. It therefore keeps that value even if the conversion
//   that follows fails.
653 Status FillPredictRequestFromJson(
654 const absl::string_view json,
655 const std::function<tensorflow::Status(
656 const string&, ::google::protobuf::Map<string, tensorflow::TensorInfo>*)>&
658 PredictRequest* request, JsonPredictRequestFormat* format) {
659 rapidjson::Document doc;
660 *format = JsonPredictRequestFormat::kInvalid;
661 TF_RETURN_IF_ERROR(ParseJson(json, &doc));
662 TF_RETURN_IF_ERROR(FillSignature(doc, request));
664 ::google::protobuf::Map<string, tensorflow::TensorInfo> tensorinfo_map;
// An empty signature name is reported as "DEFAULT" in the error below.
665 const string& signame = request->model_spec().signature_name();
666 TF_RETURN_IF_ERROR(get_tensorinfo_map(signame, &tensorinfo_map));
667 if (tensorinfo_map.empty()) {
668 return errors::InvalidArgument(
"Failed to get input map for signature: ",
669 signame.empty() ?
"DEFAULT" : signame);
675 auto itr_instances = doc.FindMember(kPredictRequestInstancesKey);
676 auto itr_inputs = doc.FindMember(kPredictRequestInputsKey);
677 if (itr_instances != doc.MemberEnd()) {
678 if (itr_inputs != doc.MemberEnd()) {
679 return FormatError(doc,
"Not formatted correctly expecting only",
680 " one of '", kPredictRequestInputsKey,
"' or '",
681 kPredictRequestInstancesKey,
"' keys to exist ");
683 if (!itr_instances->value.IsArray()) {
684 return FormatError(doc,
"Expecting '",
685 kPredictRequestInstancesKey,
"' to be an list/array");
// NOTE(review): Capacity() reports allocated storage, not element count.
// Empty() states the intent ("no values") directly.
687 if (!itr_instances->value.Capacity()) {
688 return FormatError(doc,
"No values in '",
689 kPredictRequestInstancesKey,
"' array");
691 *format = JsonPredictRequestFormat::kRow;
692 return FillTensorMapFromInstancesList(itr_instances, tensorinfo_map,
693 request->mutable_inputs());
694 }
else if (itr_inputs != doc.MemberEnd()) {
// NOTE(review): unreachable. This branch runs only when itr_instances ==
// doc.MemberEnd(), and the both-keys case is already rejected above.
695 if (itr_instances != doc.MemberEnd()) {
696 return FormatError(doc,
"Not formatted correctly expecting only",
697 " one of '", kPredictRequestInputsKey,
"' or '",
698 kPredictRequestInstancesKey,
"' keys to exist ");
700 *format = JsonPredictRequestFormat::kColumnar;
701 return FillTensorMapFromInputsMap(itr_inputs, tensorinfo_map,
702 request->mutable_inputs());
704 return errors::InvalidArgument(
"Missing 'inputs' or 'instances' key");
// Returns true if `feature` can accept a value of `kind`. That holds when no
// value has been added yet (KIND_NOT_SET) or the feature already holds `kind`.
// As a result, the first value added to a feature fixes its list type.
709 bool IsFeatureOfKind(
const Feature& feature, Feature::KindCase kind) {
710 return feature.kind_case() == Feature::KIND_NOT_SET ||
711 feature.kind_case() == kind;
// Builds the InvalidArgument returned when a value's type does not match the
// list type `feature` already holds. `kind_str` names that existing kind.
// Its declaration and the assignments for the set kinds are on lines not
// visible here; an unset kind reports "UNKNOWN".
714 Status IncompatibleFeatureKindError(
const string& feature_name,
715 const Feature& feature) {
717 switch (feature.kind_case()) {
718 case Feature::KindCase::kBytesList:
721 case Feature::KindCase::kFloatList:
724 case Feature::KindCase::kInt64List:
727 case Feature::KindCase::KIND_NOT_SET:
728 kind_str =
"UNKNOWN";
731 return errors::InvalidArgument(
"Unexpected element type in feature: ",
732 feature_name,
" expecting type: ", kind_str);
// Appends the scalar JSON value `val` to `feature`; `feature_name` is used in
// error messages. The list type follows from the JSON type:
//   true / false                   -> int64_list (1 / 0)
//   {"b64": "..."}                 -> bytes_list (base64-decoded)
//   string                         -> bytes_list
//   number stored as double        -> float_list (narrowed via GetFloat)
//   integer number                 -> int64_list; values representable only
//                                     as uint64 are rejected
//   null, array, any other object  -> InvalidArgument
// All values of one feature must map to the same list type. A mismatch
// returns IncompatibleFeatureKindError.
// A JSON number is "stored as double" by rapidjson when it was written with a
// fraction or exponent.
// NOTE(review): the break/return statements that end each successful case are
// on lines not visible in this excerpt.
737 Status AddValueToFeature(
const rapidjson::Value& val,
738 const string& feature_name, Feature* feature) {
739 switch (val.GetType()) {
740 case rapidjson::kNullType:
741 return errors::InvalidArgument(
742 "Feature: ", feature_name,
743 " has element with unexpected JSON type: ", JsonTypeString(val));
744 case rapidjson::kFalseType:
745 case rapidjson::kTrueType:
746 if (!IsFeatureOfKind(*feature, Feature::KindCase::kInt64List)) {
747 return IncompatibleFeatureKindError(feature_name, *feature);
749 feature->mutable_int64_list()->add_value(val.GetBool() ? 1 : 0);
// Only {"b64": "..."} objects are valid feature values.
751 case rapidjson::kObjectType:
752 if (!IsValBase64Object(val)) {
753 return errors::InvalidArgument(
754 "Feature: ", feature_name,
755 " has element with unexpected JSON type: ", JsonTypeString(val));
757 if (!IsFeatureOfKind(*feature, Feature::KindCase::kBytesList)) {
758 return IncompatibleFeatureKindError(feature_name, *feature);
760 TF_RETURN_IF_ERROR(JsonDecodeBase64Object(
761 val, feature->mutable_bytes_list()->add_value()));
// Nested arrays are not allowed; the caller flattens one level only.
763 case rapidjson::kArrayType:
764 return errors::InvalidArgument(
765 "Feature: ", feature_name,
766 " has element with unexpected JSON type: ", JsonTypeString(val));
767 case rapidjson::kStringType:
768 if (!IsFeatureOfKind(*feature, Feature::KindCase::kBytesList)) {
769 return IncompatibleFeatureKindError(feature_name, *feature);
771 feature->mutable_bytes_list()->add_value(val.GetString(),
772 val.GetStringLength());
774 case rapidjson::kNumberType:
775 if (val.IsDouble()) {
776 if (!IsFeatureOfKind(*feature, Feature::KindCase::kFloatList)) {
777 return IncompatibleFeatureKindError(feature_name, *feature);
779 feature->mutable_float_list()->add_value(val.GetFloat());
781 if (!IsFeatureOfKind(*feature, Feature::KindCase::kInt64List)) {
782 return IncompatibleFeatureKindError(feature_name, *feature);
// Integers above INT64_MAX would not survive GetInt64().
784 if (!val.IsInt64() && val.IsUint64()) {
785 return errors::InvalidArgument(
786 "Feature: ", feature_name,
787 " has uint64_t element. Only int64_t is supported.");
789 feature->mutable_int64_list()->add_value(val.GetInt64());
// Builds an Example from the JSON object `val`. Each key becomes a feature
// whose values are:
// - the elements of the key's array value, or
// - the key's single scalar value.
// The output Example parameter is declared on a line not visible here; it is
// used below as `example`. Returns InvalidArgument if `val` is not an object
// or any value is rejected by AddValueToFeature.
// NOTE(review):
// - The inner loop variable `val` shadows the parameter `val`.
// - `feature` is built locally (its declaration is not visible) and then
//   copied into the map; building it in place would avoid the copy.
// - A duplicated JSON key overwrites the earlier feature of that name.
796 Status MakeExampleFromJsonObject(
const rapidjson::Value& val,
798 if (!val.IsObject()) {
799 return errors::InvalidArgument(
"Example must be JSON object but got JSON ",
800 JsonTypeString(val));
802 for (
const auto& kv : val.GetObject()) {
803 const string& name = kv.name.GetString();
804 const auto& content = kv.value;
// An array supplies a list of values; anything else is a single value.
806 if (content.IsArray()) {
807 for (
const auto& val : content.GetArray()) {
808 TF_RETURN_IF_ERROR(AddValueToFeature(val, name, &feature));
811 TF_RETURN_IF_ERROR(AddValueToFeature(content, name, &feature));
813 (*example->mutable_features()->mutable_feature())[name] = feature;
// Shared implementation for Classify and Regress REST requests. Steps:
// 1. Parse `json` and apply "signature_name".
// 2. Convert an optional "context" object into
//    example_list_with_context.context.
// 3. Convert the required, non-empty "examples" array into Examples.
// Examples go into example_list_with_context when a context was given, and
// into example_list otherwise. The selection presumably uses `has_context`,
// which is set on a line not visible here.
// NOTE(review): the missing-key message says "When method is classify",
// although FillRegressionRequestFromJson uses this template too.
818 template <
typename RequestProto>
819 Status FillClassifyRegressRequestFromJson(
const absl::string_view json,
820 RequestProto* request) {
821 rapidjson::Document doc;
822 TF_RETURN_IF_ERROR(ParseJson(json, &doc));
823 TF_RETURN_IF_ERROR(FillSignature(doc, request));
826 bool has_context =
false;
827 auto*
const input = request->mutable_input();
828 auto itr = doc.FindMember(kClassifyRegressRequestContextKey);
829 if (itr != doc.MemberEnd()) {
830 TF_RETURN_IF_ERROR(MakeExampleFromJsonObject(
832 input->mutable_example_list_with_context()->mutable_context()));
837 itr = doc.FindMember(kClassifyRegressRequestExamplesKey);
838 if (itr == doc.MemberEnd()) {
839 return FormatError(doc,
"When method is classify, key '",
840 kClassifyRegressRequestExamplesKey,
841 "' is expected and was not found");
843 if (!itr->value.IsArray()) {
844 return FormatError(doc,
"Expecting '",
845 kClassifyRegressRequestExamplesKey,
846 "' value to be an list/array");
// NOTE(review): Capacity() is used as the emptiness test; Empty() states
// the intent directly.
848 if (!itr->value.Capacity()) {
849 return FormatError(doc,
"'", kClassifyRegressRequestExamplesKey,
850 "' value is an empty array");
852 for (
const auto& val : itr->value.GetArray()) {
853 TF_RETURN_IF_ERROR(MakeExampleFromJsonObject(
855 ? input->mutable_example_list_with_context()->add_examples()
856 : input->mutable_example_list()->add_examples()));
// Fills a ClassificationRequest from a Classify REST request body. It
// delegates to FillClassifyRegressRequestFromJson.
864 Status FillClassificationRequestFromJson(
const absl::string_view json,
865 ClassificationRequest* request) {
866 return FillClassifyRegressRequestFromJson(json, request);
// Fills a RegressionRequest from a Regress REST request body. It delegates to
// FillClassifyRegressRequestFromJson.
869 Status FillRegressionRequestFromJson(
const absl::string_view json,
870 RegressionRequest* request) {
871 return FillClassifyRegressRequestFromJson(json, request);
// Returns true if `tensor` is DT_STRING and `name` ends with "_bytes"
// (kBytesTensorNameSuffix). The JSON writers pass this as `string_as_bytes`,
// so such tensors are emitted as {"b64": ...} objects.
877 bool IsNamedTensorBytes(
const string& name,
const TensorProto& tensor) {
883 return tensor.dtype() == DT_STRING &&
884 absl::EndsWith(name, kBytesTensorNameSuffix);
// Writes element `*offset` of `tensor`'s typed value field to `writer`. The
// field is chosen by dtype; the case labels are not visible here.
// - String values are written as {"b64": <base64>} objects when
//   `string_as_bytes` is set, and as plain JSON strings otherwise.
// - Returns InvalidArgument when the write reports failure.
// The failure guard and the advance of `*offset` are on lines not visible in
// this excerpt.
// NOTE(review): values are read only from the repeated *_val fields. This
// excerpt cannot confirm that callers guarantee one entry per element (rather
// than, e.g., tensor_content) before indexing with *offset.
888 Status AddSingleValueAndAdvance(
const TensorProto& tensor,
bool string_as_bytes,
889 RapidJsonWriter* writer,
int* offset) {
890 bool success =
false;
891 switch (tensor.dtype()) {
893 success = WriteDecimal(writer, tensor.float_val(*offset));
897 success = WriteDecimal(writer, tensor.double_val(*offset));
904 success = writer->Int(tensor.int_val(*offset));
908 const string& str = tensor.string_val(*offset);
// Binary payloads are wrapped as {"b64": "<encoded>"} objects.
909 if (string_as_bytes) {
912 absl::Base64Escape(str, &base64);
913 writer->StartObject();
914 writer->Key(kBase64Key);
915 success = writer->String(base64.c_str(), base64.size());
918 success = writer->String(str.c_str(), str.size());
924 success = writer->Int64(tensor.int64_val(*offset));
928 success = writer->Bool(tensor.bool_val(*offset));
932 success = writer->Uint(tensor.uint32_val(*offset));
936 success = writer->Uint64(tensor.uint64_val(*offset));
944 return errors::InvalidArgument(
945 "Failed to write JSON value for tensor type: ",
946 DataTypeString(tensor.dtype()));
// Recursively writes `tensor`'s values as nested JSON arrays that follow its
// shape, starting at dimension `dim`.
// - If `dim` is past the last dimension, a single scalar is written.
// - `*offset` is the running index into the flat value list. It is shared
//   across all calls, so values are consumed in row-major order.
// The matching EndArray call is on a line not visible here.
// NOTE(review): the loop index `i` is an int while dim().size() is int64.
952 Status AddTensorValues(
const TensorProto& tensor,
bool string_as_bytes,
int dim,
953 RapidJsonWriter* writer,
int* offset) {
956 if (dim > tensor.tensor_shape().dim_size() - 1) {
957 return AddSingleValueAndAdvance(tensor, string_as_bytes, writer, offset);
959 writer->StartArray();
// Innermost dimension: write its scalars directly.
960 if (dim == tensor.tensor_shape().dim_size() - 1) {
961 for (
int i = 0; i < tensor.tensor_shape().dim(dim).size(); i++) {
963 AddSingleValueAndAdvance(tensor, string_as_bytes, writer, offset));
// Outer dimension: one nested array per index.
966 for (
int i = 0; i < tensor.tensor_shape().dim(dim).size(); i++) {
968 AddTensorValues(tensor, string_as_bytes, dim + 1, writer, offset));
// Serializes `tensor_map` as {"predictions": [...]} into `*json`, with one
// entry per batch item along dimension 0.
// - Every tensor must have at least one dimension, and all tensors must share
//   the same dimension-0 size.
// - With one tensor, each entry is that tensor's item value. With several
//   tensors, each entry is an object keyed by tensor name.
// - Each item's arrays are written on a single line.
// NOTE(review):
// - The declaration of `batch_size` is not visible. The `>= 0` test suggests
//   it starts negative to mean "not yet seen".
// - dim(0).size() is int64 but is stored in an int.
// - Key order follows tensor_map's iteration order, which
//   google::protobuf::Map does not guarantee.
975 Status MakeRowFormatJsonFromTensors(
976 const ::google::protobuf::Map<string, TensorProto>& tensor_map,
string* json) {
// Per-tensor read position into the flat value list, advanced across items.
982 std::unordered_map<string, int> offset_map;
983 for (
const auto& kv : tensor_map) {
984 const auto& name = kv.first;
985 const auto& tensor = kv.second;
986 if (tensor.tensor_shape().dim_size() == 0) {
987 return errors::InvalidArgument(
"Tensor name: ", name,
988 " has no shape information ");
990 const int cur_batch_size = tensor.tensor_shape().dim(0).size();
991 if (batch_size >= 0 && batch_size != cur_batch_size) {
992 return errors::InvalidArgument(
993 "Tensor name: ", name,
994 " has inconsistent batch size: ", cur_batch_size,
995 " expecting: ", batch_size);
997 batch_size = cur_batch_size;
998 offset_map.insert({name, 0});
1001 rapidjson::StringBuffer buffer;
1002 rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
1003 writer.StartObject();
1004 writer.Key(kPredictResponsePredictionsKey);
1005 writer.StartArray();
1006 const bool elements_are_objects = tensor_map.size() > 1;
1007 for (
int item = 0; item < batch_size; item++) {
1008 if (elements_are_objects) writer.StartObject();
// Keep each item's arrays on one line; restore default formatting after it.
1009 writer.SetFormatOptions(rapidjson::kFormatSingleLineArray);
1010 for (
const auto& kv : tensor_map) {
1011 const auto& name = kv.first;
1012 const auto& tensor = kv.second;
1013 if (elements_are_objects) writer.Key(name.c_str());
1014 TF_RETURN_IF_ERROR(AddTensorValues(
1015 tensor, IsNamedTensorBytes(name, tensor),
1017 &writer, &offset_map.at(name)));
1019 writer.SetFormatOptions(rapidjson::kFormatDefault);
1020 if (elements_are_objects) writer.EndObject();
1024 json->assign(buffer.GetString());
// Serializes `tensor_map` as {"outputs": ...} into `*json`.
// - With a single tensor, the value is that tensor as nested arrays.
// - With several tensors, the value is an object mapping each tensor name to
//   its nested arrays.
// No batch dimension is required or split. Each tensor is written whole,
// starting at dimension 0.
1028 Status MakeColumnarFormatJsonFromTensors(
1029 const ::google::protobuf::Map<string, TensorProto>& tensor_map,
string* json) {
1030 rapidjson::StringBuffer buffer;
1031 rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
1032 writer.StartObject();
1033 writer.Key(kPredictResponseOutputsKey);
1034 const bool elements_are_objects = tensor_map.size() > 1;
1035 if (elements_are_objects) writer.StartObject();
1036 for (
const auto& kv : tensor_map) {
1037 const auto& name = kv.first;
1038 const auto& tensor = kv.second;
1039 if (elements_are_objects) writer.Key(name.c_str());
// Each tensor's values start at index 0; the final offset is not needed.
1040 int unused_offset = 0;
1041 TF_RETURN_IF_ERROR(AddTensorValues(tensor, IsNamedTensorBytes(name, tensor),
1042 0, &writer, &unused_offset));
1044 if (elements_are_objects) writer.EndObject();
1046 json->assign(buffer.GetString());
// Serializes predict output tensors into `*json` using `format`:
// - kRow -> {"predictions": ...}
// - kColumnar -> {"outputs": ...}
// An empty map or kInvalid is rejected with InvalidArgument. `format` is
// presumably the value reported by FillPredictRequestFromJson for the request.
// The switch statement header is on a line not visible here.
1052 Status MakeJsonFromTensors(const ::google::protobuf::Map<string, TensorProto>& tensor_map,
1053 JsonPredictRequestFormat format,
string* json) {
1054 if (tensor_map.empty()) {
1055 return errors::InvalidArgument(
"Cannot convert empty tensor map to JSON");
1059 case JsonPredictRequestFormat::kInvalid:
1060 return errors::InvalidArgument(
"Invalid request format");
1061 case JsonPredictRequestFormat::kRow:
1062 return MakeRowFormatJsonFromTensors(tensor_map, json);
1063 case JsonPredictRequestFormat::kColumnar:
1064 return MakeColumnarFormatJsonFromTensors(tensor_map, json);
// Serializes `result` as {"results": [...]}. There is one array per
// Classifications entry, and each holds one [label, score] array per class.
// - The output string parameter is on a line not visible here; it is used
//   below as `json`.
// - The closing EndArray/EndObject calls are also not visible.
// Returns:
// - InvalidArgument for an empty result;
// - Internal if a label or score cannot be written.
1068 Status MakeJsonFromClassificationResult(
const ClassificationResult& result,
1070 if (result.classifications_size() == 0) {
1071 return errors::InvalidArgument(
1072 "Cannot convert empty ClassificationResults to JSON");
1075 rapidjson::StringBuffer buffer;
1076 rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
1077 writer.StartObject();
1078 writer.Key(kClassifyRegressResponseKey);
1079 writer.StartArray();
1080 for (
const auto& classifications : result.classifications()) {
// Keep each classification's class list on a single line.
1081 writer.SetFormatOptions(rapidjson::kFormatSingleLineArray);
1082 writer.StartArray();
1083 for (
const auto& elem : classifications.classes()) {
1084 writer.StartArray();
1085 if (!writer.String(elem.label().c_str(), elem.label().size())) {
1086 return errors::Internal(
"Failed to write class label: ", elem.label(),
1087 " to output JSON buffer");
1089 if (!WriteDecimal(&writer, elem.score())) {
1090 return errors::Internal(
"Failed to write class score : ", elem.score(),
1091 " to output JSON buffer");
1096 writer.SetFormatOptions(rapidjson::kFormatDefault);
1100 json->assign(buffer.GetString());
// Serializes `result` as {"results": [v0, v1, ...]}, with one number per
// Regression, on a single line.
// - The output string parameter is on a line not visible here; it is used
//   below as `json`.
// Returns:
// - InvalidArgument for an empty result;
// - Internal if a value cannot be written.
1104 Status MakeJsonFromRegressionResult(
const RegressionResult& result,
1106 if (result.regressions_size() == 0) {
1107 return errors::InvalidArgument(
1108 "Cannot convert empty RegressionResults to JSON");
1111 rapidjson::StringBuffer buffer;
1112 rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
1113 writer.StartObject();
1114 writer.Key(kClassifyRegressResponseKey);
1115 writer.SetFormatOptions(rapidjson::kFormatSingleLineArray);
1116 writer.StartArray();
1117 for (
const auto& regression : result.regressions()) {
1118 if (!WriteDecimal(&writer, regression.value())) {
1119 return errors::Internal(
"Failed to write regression value : ",
1120 regression.value(),
" to output JSON buffer");
1125 json->assign(buffer.GetString());
// Appends {"error": <status message>} to `*json` when `status` is not OK.
// For an OK status, `*json` is left untouched.
// NOTE(review): unlike the Make*Json functions above, this appends rather
// than assigns. Callers presumably pass an empty string.
1129 void MakeJsonFromStatus(
const tensorflow::Status& status,
string* json) {
1130 if (status.ok())
return;
1131 absl::string_view error_message = status.message();
1132 rapidjson::StringBuffer buffer;
1133 rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
1134 writer.StartObject();
1135 writer.Key(kErrorResponseKey);
// Length-based overload: the message need not be NUL-terminated.
1136 writer.String(error_message.data(), error_message.size());
1138 json->append(buffer.GetString());