From 7026f98edb01eedb9453ed233293c9fe67e57f74 Mon Sep 17 00:00:00 2001 From: wjm Date: Mon, 7 Jun 2021 05:12:05 +0800 Subject: [PATCH] custom register --- parser/onnx/onnx_custom_parser_adapter.cc | 11 +++++ parser/onnx/onnx_custom_parser_adapter.h | 2 + parser/onnx/onnx_parser.cc | 44 ++++++++++++++++--- parser/onnx/onnx_parser.h | 2 + .../onnx_parser_unittest.cc | 7 ++- 5 files changed, 59 insertions(+), 7 deletions(-) diff --git a/parser/onnx/onnx_custom_parser_adapter.cc b/parser/onnx/onnx_custom_parser_adapter.cc index 3b7e7b0..b3e112d 100644 --- a/parser/onnx/onnx_custom_parser_adapter.cc +++ b/parser/onnx/onnx_custom_parser_adapter.cc @@ -21,6 +21,7 @@ #include "register/op_registry.h" using domi::ParseParamFunc; +using domi::ParseParamByOpFunc; using domi::ONNX; namespace ge{ @@ -40,5 +41,15 @@ Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator return SUCCESS; } +Status OnnxCustomParserAdapter::ParseParams(const Operator &op_src, Operator &op_dest) { + ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); + GE_CHECK_NOTNULL(custom_op_parser); + + GE_CHK_BOOL_RET_STATUS(custom_op_parser(op_src, op_dest) == SUCCESS, FAILED, + "[Invoke][CustomOpParser] failed, node name:%s, type:%s", + op_src.GetName().c_str(), op_src.GetOpType().c_str()); + return SUCCESS; +} + REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(ONNX, OnnxCustomParserAdapter); } // namespace ge diff --git a/parser/onnx/onnx_custom_parser_adapter.h b/parser/onnx/onnx_custom_parser_adapter.h index 1e5f147..7e0fb06 100644 --- a/parser/onnx/onnx_custom_parser_adapter.h +++ b/parser/onnx/onnx_custom_parser_adapter.h @@ -28,6 +28,8 @@ class PARSER_FUNC_VISIBILITY OnnxCustomParserAdapter : public OnnxOpParser { /// @return SUCCESS parse successfully /// @return FAILED parse failed Status ParseParams(const Message *op_src, ge::Operator &op_dest) override; + + Status ParseParams(const Operator &op_src, Operator &op_dest); }; } // namespace ge diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index cc69799..eef2041 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -36,6 +36,7 @@ #include "parser/common/model_saver.h" #include "parser/common/parser_utils.h" #include "parser/common/prototype_pass_manager.h" +#include "parser/onnx/onnx_custom_parser_adapter.h" #include "parser/onnx/onnx_util.h" #include "register/op_registry.h" #include "register/register_fmk_types.h" @@ -555,6 +556,41 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { return SUCCESS; } +Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, + std::shared_ptr &op_parser) { + GE_CHECK_NOTNULL(node_proto); + GE_CHECK_NOTNULL(op_parser); + std::string op_type = node_proto->op_type(); + + Status status = FAILED; + domi::ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type); + if (parse_param_func == nullptr) { + //std::shared_ptr onnx_op_parser = std::static_pointer_cast(op_parser); + //GE_CHECK_NOTNULL(onnx_op_parser); + status = op_parser->ParseParams(node_proto, op); + } else { + ge::Operator op_src(node_proto->name(), op_type); + /*status = domi::AutoMappingFn(node_def, op_src); + if (status != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", + node_def->name().c_str(), node_def->op().c_str()); + GELOGE(status, "Node[%s] auto mapping failed.", node_name.c_str()); + return status; + }*/ + std::shared_ptr onnx_custom_op_parser = + std::dynamic_pointer_cast(op_parser); + status = onnx_custom_op_parser->ParseParams(op_src, op); + } + + if (status != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E11010", {"opname", "optype"}, {node_proto->name(), op_type}); + GELOGE(status, "[Parse][Params] for op [%s] fail, optype [%s]", node_proto->name().c_str(), op_type.c_str()); + return status; + } + + return SUCCESS; +} + Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { for (int i = 0; i < onnx_graph.node_size(); i++) { ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); @@ -586,19 +622,15 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: GE_CHECK_NOTNULL(factory); std::shared_ptr op_parser = factory->CreateOpParser(op_type); GE_CHECK_NOTNULL(op_parser); - std::shared_ptr onnx_op_parser = std::static_pointer_cast(op_parser); - GE_CHECK_NOTNULL(onnx_op_parser); - status = onnx_op_parser->ParseParams(node_proto, op); + status = ParseOpParam(node_proto, op, op_parser); if (status != SUCCESS) { - REPORT_CALL_ERROR("E19999", "ParseParams for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); - GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); + GELOGE(status, "Parse params for node[%s] failed", node_name.c_str()); return status; } GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); - ge::graphStatus graph_status = graph.AddOp(op); if (graph_status != ge::GRAPH_SUCCESS) { GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str()); diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h index b28494b..b90c1a3 100644 --- a/parser/onnx/onnx_parser.h +++ b/parser/onnx/onnx_parser.h @@ -110,6 +110,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { void ClearMembers(); + Status ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, std::shared_ptr &op_parser); + Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, std::map &name_to_onnx_graph); diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index a3c54d5..ae4cb49 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -39,6 +39,10 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& return SUCCESS; } +static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) { + return SUCCESS; +} + Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); @@ -77,7 +81,8 @@ void UtestOnnxParser::RegisterCustomOp() { REGISTER_CUSTOM_OP("Add") .FrameworkType(domi::ONNX) .OriginOpType("ai.onnx::11::Add") - .ParseParamsFn(ParseParams); + .ParseParamsFn(ParseParams) + .ParseParamsByOperatorFn(ParseParamByOpFunc); REGISTER_CUSTOM_OP("Identity") .FrameworkType(domi::ONNX)