|
|
|
@@ -36,6 +36,7 @@ |
|
|
|
#include "parser/common/model_saver.h" |
|
|
|
#include "parser/common/parser_utils.h" |
|
|
|
#include "parser/common/prototype_pass_manager.h" |
|
|
|
#include "parser/onnx/onnx_custom_parser_adapter.h" |
|
|
|
#include "parser/onnx/onnx_util.h" |
|
|
|
#include "register/op_registry.h" |
|
|
|
#include "register/register_fmk_types.h" |
|
|
|
@@ -555,6 +556,41 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, |
|
|
|
std::shared_ptr<OpParser> &op_parser) { |
|
|
|
GE_CHECK_NOTNULL(node_proto); |
|
|
|
GE_CHECK_NOTNULL(op_parser); |
|
|
|
std::string op_type = node_proto->op_type(); |
|
|
|
|
|
|
|
Status status = FAILED; |
|
|
|
domi::ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type); |
|
|
|
if (parse_param_func == nullptr) { |
|
|
|
//std::shared_ptr<ge::OnnxOpParser> onnx_op_parser = std::static_pointer_cast<ge::OnnxOpParser>(op_parser); |
|
|
|
//GE_CHECK_NOTNULL(onnx_op_parser); |
|
|
|
status = op_parser->ParseParams(node_proto, op); |
|
|
|
} else { |
|
|
|
ge::Operator op_src(node_proto->name(), op_type); |
|
|
|
/*status = domi::AutoMappingFn(node_def, op_src); |
|
|
|
if (status != SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", |
|
|
|
node_def->name().c_str(), node_def->op().c_str()); |
|
|
|
GELOGE(status, "Node[%s] auto mapping failed.", node_name.c_str()); |
|
|
|
return status; |
|
|
|
}*/ |
|
|
|
std::shared_ptr<ge::OnnxCustomParserAdapter> onnx_custom_op_parser = |
|
|
|
std::dynamic_pointer_cast<ge::OnnxCustomParserAdapter>(op_parser); |
|
|
|
status = onnx_custom_op_parser->ParseParams(op_src, op); |
|
|
|
} |
|
|
|
|
|
|
|
if (status != SUCCESS) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E11010", {"opname", "optype"}, {node_proto->name(), op_type}); |
|
|
|
GELOGE(status, "[Parse][Params] for op [%s] fail, optype [%s]", node_proto->name().c_str(), op_type.c_str()); |
|
|
|
return status; |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { |
|
|
|
for (int i = 0; i < onnx_graph.node_size(); i++) { |
|
|
|
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); |
|
|
|
@@ -586,19 +622,15 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: |
|
|
|
GE_CHECK_NOTNULL(factory); |
|
|
|
std::shared_ptr<ge::OpParser> op_parser = factory->CreateOpParser(op_type); |
|
|
|
GE_CHECK_NOTNULL(op_parser); |
|
|
|
std::shared_ptr<ge::OnnxOpParser> onnx_op_parser = std::static_pointer_cast<ge::OnnxOpParser>(op_parser); |
|
|
|
GE_CHECK_NOTNULL(onnx_op_parser); |
|
|
|
status = onnx_op_parser->ParseParams(node_proto, op); |
|
|
|
status = ParseOpParam(node_proto, op, op_parser); |
|
|
|
if (status != SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "ParseParams for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); |
|
|
|
GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); |
|
|
|
GELOGE(status, "Parse params for node[%s] failed", node_name.c_str()); |
|
|
|
return status; |
|
|
|
} |
|
|
|
|
|
|
|
GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), |
|
|
|
op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); |
|
|
|
|
|
|
|
|
|
|
|
ge::graphStatus graph_status = graph.AddOp(op); |
|
|
|
if (graph_status != ge::GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str()); |
|
|
|
|