From 7d522a1cdd48d2d7cee2729870089329292eb1be Mon Sep 17 00:00:00 2001 From: zhengyuanhua Date: Thu, 7 Jan 2021 20:09:39 +0800 Subject: [PATCH] dts:caffe parser support parse repeated message --- parser/caffe/caffe_parser.cc | 22 ++++++++++++++++------ parser/common/convert/pb2json.h | 2 +- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 375c1c4..f94b071 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -193,6 +193,7 @@ const int kMaxParseDepth = 5; const int32_t kMinLineWorldSize = 3; const int32_t kMaxIdentifier = 536870911; // 2^29 - 1 const int32_t kBase = 10; +const uint32_t kInteval = 2; const char *const kPython = "Python"; const char *const kProposalLayer = "ProposalLayer"; const char *const kDetectionOutput = "DetectionOutput"; @@ -793,13 +794,22 @@ Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection * CASE_FIELD_TYPE_REPEATED(STRING, String, string); #undef CASE_FIELD_TYPE_REPEATED case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { - for (int i = 0; i < field_size; ++i) { - const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); - if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { - GELOGE(FAILED, "ParseOperatorAttrs of field: %s failed.", field->name().c_str()); - return FAILED; - } + nlohmann::json message_json; + Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set(), + message_json[field->name()], false); + std::string repeated_message_str; + try { + repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); + } catch (std::exception &e) { + ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()}); + GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what()); + return FAILED; + } catch (...) { + ErrorManager::GetInstance().ATCReportErrMessage("E19008"); + GELOGE(FAILED, "Failed to convert JSON to string."); + return FAILED; } + (void)ops.SetAttr(field->name(), repeated_message_str); break; } default: { diff --git a/parser/common/convert/pb2json.h b/parser/common/convert/pb2json.h index 7bc55b1..4f8e406 100644 --- a/parser/common/convert/pb2json.h +++ b/parser/common/convert/pb2json.h @@ -47,11 +47,11 @@ class Pb2Json { static void Message2Json(const ProtobufMsg &message, const std::set &black_fields, Json &json, bool enum2str = false); - protected: static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, const ProtobufReflection *reflection, const std::set &black_fields, Json &json, bool enum2str); + protected: static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, bool enum2str, Json &json);