Browse Source

custom register

pull/321/head
wjm 4 years ago
parent
commit
7026f98edb
5 changed files with 59 additions and 7 deletions
  1. +11
    -0
      parser/onnx/onnx_custom_parser_adapter.cc
  2. +2
    -0
      parser/onnx/onnx_custom_parser_adapter.h
  3. +38
    -6
      parser/onnx/onnx_parser.cc
  4. +2
    -0
      parser/onnx/onnx_parser.h
  5. +6
    -1
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc

+ 11
- 0
parser/onnx/onnx_custom_parser_adapter.cc View File

@@ -21,6 +21,7 @@
#include "register/op_registry.h"

using domi::ParseParamFunc;
using domi::ParseParamByOpFunc;
using domi::ONNX;

namespace ge{
@@ -40,5 +41,15 @@ Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator
return SUCCESS;
}

Status OnnxCustomParserAdapter::ParseParams(const Operator &op_src, Operator &op_dest) {
ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType());
GE_CHECK_NOTNULL(custom_op_parser);

GE_CHK_BOOL_RET_STATUS(custom_op_parser(op_src, op_dest) == SUCCESS, FAILED,
"[Invoke][CustomOpParser] failed, node name:%s, type:%s",
op_src.GetName().c_str(), op_src.GetOpType().c_str());
return SUCCESS;
}

REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(ONNX, OnnxCustomParserAdapter);
} // namespace ge

+ 2
- 0
parser/onnx/onnx_custom_parser_adapter.h View File

@@ -28,6 +28,8 @@ class PARSER_FUNC_VISIBILITY OnnxCustomParserAdapter : public OnnxOpParser {
/// @return SUCCESS parse successfully
/// @return FAILED parse failed
Status ParseParams(const Message *op_src, ge::Operator &op_dest) override;

Status ParseParams(const Operator &op_src, Operator &op_dest);
};
} // namespace ge



+ 38
- 6
parser/onnx/onnx_parser.cc View File

@@ -36,6 +36,7 @@
#include "parser/common/model_saver.h"
#include "parser/common/parser_utils.h"
#include "parser/common/prototype_pass_manager.h"
#include "parser/onnx/onnx_custom_parser_adapter.h"
#include "parser/onnx/onnx_util.h"
#include "register/op_registry.h"
#include "register/register_fmk_types.h"
@@ -555,6 +556,41 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) {
return SUCCESS;
}

Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op,
std::shared_ptr<OpParser> &op_parser) {
GE_CHECK_NOTNULL(node_proto);
GE_CHECK_NOTNULL(op_parser);
std::string op_type = node_proto->op_type();

Status status = FAILED;
domi::ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type);
if (parse_param_func == nullptr) {
//std::shared_ptr<ge::OnnxOpParser> onnx_op_parser = std::static_pointer_cast<ge::OnnxOpParser>(op_parser);
//GE_CHECK_NOTNULL(onnx_op_parser);
status = op_parser->ParseParams(node_proto, op);
} else {
ge::Operator op_src(node_proto->name(), op_type);
/*status = domi::AutoMappingFn(node_def, op_src);
if (status != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed",
node_def->name().c_str(), node_def->op().c_str());
GELOGE(status, "Node[%s] auto mapping failed.", node_name.c_str());
return status;
}*/
std::shared_ptr<ge::OnnxCustomParserAdapter> onnx_custom_op_parser =
std::dynamic_pointer_cast<ge::OnnxCustomParserAdapter>(op_parser);
status = onnx_custom_op_parser->ParseParams(op_src, op);
}

if (status != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E11010", {"opname", "optype"}, {node_proto->name(), op_type});
GELOGE(status, "[Parse][Params] for op [%s] fail, optype [%s]", node_proto->name().c_str(), op_type.c_str());
return status;
}

return SUCCESS;
}

Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) {
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i);
@@ -586,19 +622,15 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
GE_CHECK_NOTNULL(factory);
std::shared_ptr<ge::OpParser> op_parser = factory->CreateOpParser(op_type);
GE_CHECK_NOTNULL(op_parser);
std::shared_ptr<ge::OnnxOpParser> onnx_op_parser = std::static_pointer_cast<ge::OnnxOpParser>(op_parser);
GE_CHECK_NOTNULL(onnx_op_parser);
status = onnx_op_parser->ParseParams(node_proto, op);
status = ParseOpParam(node_proto, op, op_parser);
if (status != SUCCESS) {
REPORT_CALL_ERROR("E19999", "ParseParams for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status);
GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status);
GELOGE(status, "Parse params for node[%s] failed", node_name.c_str());
return status;
}

GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(),
op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize());


ge::graphStatus graph_status = graph.AddOp(op);
if (graph_status != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str());


+ 2
- 0
parser/onnx/onnx_parser.h View File

@@ -110,6 +110,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

void ClearMembers();

Status ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, std::shared_ptr<OpParser> &op_parser);

Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);



+ 6
- 1
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -39,6 +39,10 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator&
return SUCCESS;
}

static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) {
return SUCCESS;
}

Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) {
domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func =
domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX);
@@ -77,7 +81,8 @@ void UtestOnnxParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Add")
.FrameworkType(domi::ONNX)
.OriginOpType("ai.onnx::11::Add")
.ParseParamsFn(ParseParams);
.ParseParamsFn(ParseParams)
.ParseParamsByOperatorFn(ParseParamByOpFunc);

REGISTER_CUSTOM_OP("Identity")
.FrameworkType(domi::ONNX)


Loading…
Cancel
Save