From 759b5e419a611bd4f877c24304a7828a08dd9d16 Mon Sep 17 00:00:00 2001 From: wjm Date: Tue, 8 Jun 2021 21:37:56 +0800 Subject: [PATCH] fix --- parser/caffe/caffe_parser.cc | 143 +---------------------------------- parser/caffe/caffe_parser.h | 40 ---------- parser/onnx/onnx_parser.cc | 2 - 3 files changed, 2 insertions(+), 183 deletions(-) diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 9ea40e5..157f33f 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -21,6 +21,7 @@ #include #include #include +#include "common/convert/message2operator.h" #include "parser/common/convert/pb2json.h" #include "parser/common/acl_graph_parser_util.h" #include "common/op_map.h" @@ -578,7 +579,7 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co return FAILED; } - if (ParseOperatorAttrs(message, 1, ops) != SUCCESS) { + if (Message2Operator::ParseOperatorAttrs(message, 1, ops) != SUCCESS) { GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", op_name.c_str()); return FAILED; } @@ -589,146 +590,6 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co return SUCCESS; } -Status CaffeModelParser::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) { - if (depth > kMaxParseDepth) { - REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); - GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); - return FAILED; - } - - const google::protobuf::Reflection *reflection = message->GetReflection(); - GE_CHECK_NOTNULL(reflection); - vector field_desc; - reflection->ListFields(*message, &field_desc); - - for (auto &field : field_desc) { - GE_CHECK_NOTNULL(field); - if (field->is_repeated()) { - if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { - GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); - return FAILED; - } - } else { - if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { - GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); - return FAILED; - } - } - } - return SUCCESS; -} - -Status CaffeModelParser::ParseField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, - int depth, ge::Operator &ops) { - GELOGD("Start to parse field: %s.", field->name().c_str()); - switch (field->cpp_type()) { -#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ - case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ - valuetype value = reflection->Get##method(*message, field); \ - GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ - (void)ops.SetAttr(field->name(), value); \ - break; \ - } - CASE_FIELD_TYPE(INT32, Int32, int32_t, d); - CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); - CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); - CASE_FIELD_TYPE(FLOAT, Float, float, f); - CASE_FIELD_TYPE(BOOL, Bool, bool, d); -#undef CASE_FIELD_TYPE - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { - GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); - int value = reflection->GetEnum(*message, field)->number(); - GELOGD("Parse result(%s : %d)", field->name().c_str(), value); - (void)ops.SetAttr(field->name(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { - string value = reflection->GetString(*message, field); - GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); - (void)ops.SetAttr(field->name(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { - const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); - if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { - GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str()); - return FAILED; - } - break; - } - default: { - REPORT_INPUT_ERROR("E11032", std::vector({"message_type", "name", "reason"}), - std::vector({"model", field->name(), "Unsupported field type"})); - GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); - return FAILED; - } - } - GELOGD("Parse field: %s success.", field->name().c_str()); - return SUCCESS; -} - -Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, - ge::Operator &ops) { - GELOGD("Start to parse field: %s.", field->name().c_str()); - int field_size = reflection->FieldSize(*message, field); - if (field_size <= 0) { - REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str()); - GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); - return FAILED; - } - - switch (field->cpp_type()) { -#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ - case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ - vector attr_value; \ - for (int i = 0; i < field_size; i++) { \ - valuetype value = reflection->GetRepeated##method(*message, field, i); \ - attr_value.push_back(value); \ - } \ - (void)ops.SetAttr(field->name(), attr_value); \ - break; \ - } - CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); - CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); - CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); - CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); - CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); - CASE_FIELD_TYPE_REPEATED(STRING, String, string); -#undef CASE_FIELD_TYPE_REPEATED - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { - nlohmann::json message_json; - Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set(), - message_json[field->name()], false); - std::string repeated_message_str; - try { - repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); - } catch (std::exception &e) { - REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what()); - GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what()); - return FAILED; - } catch (...) { - REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string."); - GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string."); - return FAILED; - } - (void)ops.SetAttr(field->name(), repeated_message_str); - break; - } - default: { - REPORT_INPUT_ERROR("E11032", std::vector({"message_type", "name", "reason"}), - std::vector({"model", field->name(), "Unsupported field type"})); - GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); - return FAILED; - } - } - GELOGD("Parse repeated field: %s success.", field->name().c_str()); - return SUCCESS; -} - void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) { auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name); if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) { diff --git a/parser/caffe/caffe_parser.h b/parser/caffe/caffe_parser.h index 9a3af8b..354f23e 100644 --- a/parser/caffe/caffe_parser.h +++ b/parser/caffe/caffe_parser.h @@ -209,46 +209,6 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { */ Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message, int index, std::vector &operators); - - /* - * @ingroup domi_omg - * @brief Parse message and set operator attrs - * @param [in] message, message of model - * @param [in/out] depth, depth of recursion - * @param [out] ops, operator saving custom info - * @return SUCCESS parse message successfully - * @return FAILED parse message failed - */ - Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); - - /* - * @ingroup domi_omg - * @brief Parse field and set operator attrs - * @param [in] reflection, reflection of message - * @param [in] message, message of model - * @param [in] field, field of message - * @param [in/out] depth, depth of recursion - * @param [out] ops, operator saving custom info - * @return SUCCESS parse field successfully - * @return FAILED parse field failed - */ - Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); - - /* - * @ingroup domi_omg - * @brief Parse repeated field and set operator attrs - * @param [in] reflection, reflection of message - * @param [in] message, message of model - * @param [in] field, field of message - * @param [in/out] depth, depth of recursion - * @param [out] ops, operator saving custom info by vector - * @return SUCCESS parse field successfully - * @return FAILED parse field failed - */ - Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); - /** * @ingroup domi_omg * @brief Add blob information to the bottom_blobs_map and top_blobs_map_ diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 218a1f4..a7178aa 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -566,8 +566,6 @@ Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge:: Status status = FAILED; domi::ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type); if (parse_param_func == nullptr) { - //std::shared_ptr onnx_op_parser = std::static_pointer_cast(op_parser); - //GE_CHECK_NOTNULL(onnx_op_parser); status = op_parser->ParseParams(node_proto, op); } else { ge::Operator op_src(node_proto->name(), op_type);