|
|
@@ -17,6 +17,7 @@ |
|
|
#include "onnx_parser.h" |
|
|
#include "onnx_parser.h" |
|
|
#include <algorithm> |
|
|
#include <algorithm> |
|
|
#include <iostream> |
|
|
#include <iostream> |
|
|
|
|
|
#include <queue> |
|
|
#include "common/convert/pb2json.h" |
|
|
#include "common/convert/pb2json.h" |
|
|
#include "common/util.h" |
|
|
#include "common/util.h" |
|
|
#include "common/util/error_manager/error_manager.h" |
|
|
#include "common/util/error_manager/error_manager.h" |
|
|
@@ -37,6 +38,9 @@ |
|
|
#include "parser/onnx/onnx_util.h" |
|
|
#include "parser/onnx/onnx_util.h" |
|
|
#include "register/op_registry.h" |
|
|
#include "register/op_registry.h" |
|
|
#include "register/register_fmk_types.h" |
|
|
#include "register/register_fmk_types.h" |
|
|
|
|
|
#include "graph/utils/graph_utils.h" |
|
|
|
|
|
#include "graph/utils/node_utils.h" |
|
|
|
|
|
#include "subgraph_adapter/subgraph_adapter_factory.h" |
|
|
|
|
|
|
|
|
namespace ge { |
|
|
namespace ge { |
|
|
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, |
|
|
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, |
|
|
@@ -95,7 +99,7 @@ graphStatus aclgrphParseONNX(const char *model_file, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
GE_CHECK_NOTNULL(model_parser); |
|
|
GE_CHECK_NOTNULL(model_parser); |
|
|
// parse caffe model_file to GE graph |
|
|
|
|
|
|
|
|
// parse onnx model_file to GE graph |
|
|
ge::graphStatus ret = model_parser->Parse(model_file, graph); |
|
|
ge::graphStatus ret = model_parser->Parse(model_file, graph); |
|
|
if (ret != ge::SUCCESS) { |
|
|
if (ret != ge::SUCCESS) { |
|
|
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); |
|
|
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); |
|
|
@@ -144,18 +148,130 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, |
|
|
namespace ge { |
|
|
namespace ge { |
|
|
namespace { |
|
|
namespace { |
|
|
const std::map<std::string, std::string> kOnnxOpMap = { |
|
|
const std::map<std::string, std::string> kOnnxOpMap = { |
|
|
{ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, |
|
|
|
|
|
|
|
|
{ge::kOpTypeInput, ge::parser::DATA}, |
|
|
|
|
|
{ge::kOpTypeConstant, ge::parser::CONSTANT} |
|
|
}; |
|
|
}; |
|
|
const char* const MATMULV2 = "MatMulV2"; |
|
|
|
|
|
const std::vector<std::string> kNoNeedUpdateFormat = {MATMULV2}; |
|
|
|
|
|
const int64_t kDimValue = 1; |
|
|
const int64_t kDimValue = 1; |
|
|
|
|
|
|
|
|
|
|
|
struct ParseArg { |
|
|
|
|
|
ge::onnx::GraphProto *onnx_graph; |
|
|
|
|
|
ge::NodePtr parent_node; |
|
|
|
|
|
std::string graph_name; |
|
|
|
|
|
uint32_t subgraph_index; |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque<ParseArg> &args) { |
|
|
|
|
|
GELOGI("Gen subgraph parse tasks start"); |
|
|
|
|
|
for (auto &node : parent_graph->GetDirectNode()) { |
|
|
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
|
|
for (const auto subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) { |
|
|
|
|
|
auto i = subgraph_name_to_index.second; |
|
|
|
|
|
auto subgraph_iname = subgraph_name_to_index.first; |
|
|
|
|
|
if (subgraph_iname.empty()) { |
|
|
|
|
|
GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str()); |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// change the graph name to ensure it is unique in GE |
|
|
|
|
|
std::string unique_subgraph_name; |
|
|
|
|
|
OnnxUtil::GenUniqueSubgraphName(i, subgraph_iname, node->GetName(), unique_subgraph_name); |
|
|
|
|
|
|
|
|
|
|
|
GELOGD("Add subgraph parse task to the queue, node %s, index %u, subgraph instance name %s", |
|
|
|
|
|
node->GetName().c_str(), i, unique_subgraph_name.c_str()); |
|
|
|
|
|
args.push_back({nullptr, node, unique_subgraph_name, i}); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
GELOGI("Gen subgraph parse tasks end"); |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status BuildLinkForChildAndParentGraph(const ge::ComputeGraphPtr &sub_graph, const ParseArg &arg) { |
|
|
|
|
|
if (arg.parent_node == nullptr) { |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
auto parent_node = arg.parent_node; |
|
|
|
|
|
auto index = arg.subgraph_index; |
|
|
|
|
|
auto ret = ge::NodeUtils::SetSubgraph(*parent_node, index, sub_graph); |
|
|
|
|
|
if (ret != SUCCESS) { |
|
|
|
|
|
GELOGE(ret, "[Set][Subgraph] Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(), |
|
|
|
|
|
parent_node->GetName().c_str(), index); |
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(), |
|
|
|
|
|
parent_node->GetName().c_str(), index); |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, |
|
|
|
|
|
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) { |
|
|
|
|
|
if (onnx_graph.input_size() == 0) { |
|
|
|
|
|
|
|
|
Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_graph) { |
|
|
|
|
|
if (arg.parent_node == nullptr) { |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
std::string op_type = arg.parent_node->GetType(); |
|
|
|
|
|
std::string op_name = arg.parent_node->GetName(); |
|
|
|
|
|
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; |
|
|
|
|
|
auto post_func = |
|
|
|
|
|
domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); |
|
|
|
|
|
if (post_func == nullptr) { |
|
|
|
|
|
GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); |
|
|
|
|
|
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) { |
|
|
|
|
|
GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
GELOGD("Post process for subgraph %s node %s type %s", arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), |
|
|
|
|
|
arg.parent_node->GetType().c_str()); |
|
|
|
|
|
|
|
|
|
|
|
// Refresh node_name in subgraph |
|
|
|
|
|
for (const ge::NodePtr &node : sub_graph->GetDirectNode()) { |
|
|
|
|
|
if (node->GetOpDesc() == nullptr) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName()); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); |
|
|
|
|
|
Status ret = FAILED; |
|
|
|
|
|
if (post_func != nullptr) { |
|
|
|
|
|
ret = post_func(arg.graph_name, graph); |
|
|
|
|
|
} else if (parse_func_v2 != nullptr) { |
|
|
|
|
|
ret = parse_func_v2(arg.graph_name.c_str(), graph); |
|
|
|
|
|
} |
|
|
|
|
|
if (ret != SUCCESS) { |
|
|
|
|
|
GELOGE(FAILED, "[PostProcess][Subgraph]Failed to post-process subgraph %s on node %s type %s", |
|
|
|
|
|
arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str()); |
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Failed to post-process subgraph %s on node %s type %s", |
|
|
|
|
|
arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str()); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status OnnxModelParser::ParseOutput(ge::onnx::GraphProto &onnx_graph) { |
|
|
|
|
|
if (onnx_graph.output_size() == 0) { |
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E16001"); |
|
|
|
|
|
GELOGE(FAILED, "[Parse][Output] Onnx graph:%s has zero output", onnx_graph.name().c_str()); |
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Onnx graph:%s has zero output", onnx_graph.name().c_str()); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// get output value info map |
|
|
|
|
|
for (int i = 0; i < onnx_graph.output_size(); i++) { |
|
|
|
|
|
ge::onnx::ValueInfoProto value_info = onnx_graph.output(i); |
|
|
|
|
|
GELOGI("The index of %d output name : %s.", i, value_info.name().c_str()); |
|
|
|
|
|
output_node_names_.emplace_back(value_info.name()); |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status OnnxModelParser::ParseInput(const std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor, |
|
|
|
|
|
bool is_subgraph, ge::onnx::GraphProto &onnx_graph) { |
|
|
|
|
|
if (!is_subgraph && onnx_graph.input_size() == 0) { |
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E16001"); |
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E16001"); |
|
|
GELOGE(FAILED, "Onnx graph has zero input"); |
|
|
|
|
|
|
|
|
GELOGE(FAILED, "[Parse][Input] Root onnx graph:%s has zero input", onnx_graph.name().c_str()); |
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Root onnx graph:%s has zero input", onnx_graph.name().c_str()); |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -207,6 +323,11 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, |
|
|
ge::onnx::AttributeProto *attribute_index = input_node->add_attribute(); |
|
|
ge::onnx::AttributeProto *attribute_index = input_node->add_attribute(); |
|
|
attribute_index->set_name(ge::kAttrNameIndex); |
|
|
attribute_index->set_name(ge::kAttrNameIndex); |
|
|
attribute_index->set_i(data_index++); |
|
|
attribute_index->set_i(data_index++); |
|
|
|
|
|
// add subgraph attr |
|
|
|
|
|
if (is_subgraph) { |
|
|
|
|
|
attribute = input_node->add_attribute(); |
|
|
|
|
|
attribute->set_name(ge::kAttrNameIsSubgraphOp); |
|
|
|
|
|
} |
|
|
input_node_names_.emplace_back(value_info.name()); |
|
|
input_node_names_.emplace_back(value_info.name()); |
|
|
} |
|
|
} |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
@@ -319,7 +440,8 @@ Status OnnxModelParser::TransNodeToOperator(const ge::onnx::NodeProto *node_prot |
|
|
op = ge::OperatorFactory::CreateOperator(node_name, op_type); |
|
|
op = ge::OperatorFactory::CreateOperator(node_name, op_type); |
|
|
if (op.GetName() != node_name) { |
|
|
if (op.GetName() != node_name) { |
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type}); |
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type}); |
|
|
GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); |
|
|
|
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "[Creat][Op] IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); |
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); |
|
|
return INTERNAL_ERROR; |
|
|
return INTERNAL_ERROR; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -428,7 +550,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: |
|
|
std::string op_type; |
|
|
std::string op_type; |
|
|
Status status = AdapterOpType(node_proto, ori_type, op_type); |
|
|
Status status = AdapterOpType(node_proto, ori_type, op_type); |
|
|
if (status != SUCCESS) { |
|
|
if (status != SUCCESS) { |
|
|
GELOGE(status, "Adapter op type for ori type %s failed.", ori_type.c_str()); |
|
|
|
|
|
|
|
|
GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str()); |
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str()); |
|
|
return status; |
|
|
return status; |
|
|
} |
|
|
} |
|
|
node_proto->set_op_type(ori_type); |
|
|
node_proto->set_op_type(ori_type); |
|
|
@@ -438,7 +561,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: |
|
|
ge::Operator op; |
|
|
ge::Operator op; |
|
|
status = TransNodeToOperator(node_proto, op, op_type); |
|
|
status = TransNodeToOperator(node_proto, op, op_type); |
|
|
if (status != SUCCESS) { |
|
|
if (status != SUCCESS) { |
|
|
GELOGE(status, "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); |
|
|
|
|
|
|
|
|
GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); |
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); |
|
|
return status; |
|
|
return status; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -455,9 +579,14 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: |
|
|
return status; |
|
|
return status; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), |
|
|
|
|
|
op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ge::graphStatus graph_status = graph.AddOp(op); |
|
|
ge::graphStatus graph_status = graph.AddOp(op); |
|
|
if (graph_status != ge::GRAPH_SUCCESS) { |
|
|
if (graph_status != ge::GRAPH_SUCCESS) { |
|
|
GELOGE(FAILED, "Add op:%s to graph failed.", op.GetName().c_str()); |
|
|
|
|
|
|
|
|
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str()); |
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", op.GetName().c_str()); |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
name_operator_[op.GetName()] = op; |
|
|
name_operator_[op.GetName()] = op; |
|
|
@@ -473,18 +602,54 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) { |
|
|
|
|
|
|
|
|
Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { |
|
|
|
|
|
if (input_node_names_.empty()) { |
|
|
|
|
|
// subgraph might not have input, we use constant nodes as the start nodes of graph |
|
|
|
|
|
for (int i = 0; i < onnx_graph.node_size(); i++) { |
|
|
|
|
|
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); |
|
|
|
|
|
if (node->op_type() == kOpTypeConstant) { |
|
|
|
|
|
input_node_names_.emplace_back(node->name()); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
for (auto in_name : input_node_names_) { |
|
|
for (auto in_name : input_node_names_) { |
|
|
auto in_op = name_operator_.find(in_name); |
|
|
auto in_op = name_operator_.find(in_name); |
|
|
if (in_op == name_operator_.end()) { |
|
|
if (in_op == name_operator_.end()) { |
|
|
GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.", |
|
|
|
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Get][Inputs] Model assigned input node name: %s can not find in graph.", |
|
|
in_name.c_str()); |
|
|
in_name.c_str()); |
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Model assigned input node name: %s can not find in graph.", |
|
|
|
|
|
in_name.c_str()); |
|
|
return PARAM_INVALID; |
|
|
return PARAM_INVALID; |
|
|
} |
|
|
} |
|
|
input_ops.emplace_back(in_op->second); |
|
|
input_ops.emplace_back(in_op->second); |
|
|
GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); |
|
|
GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); |
|
|
} |
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops) { |
|
|
|
|
|
for (auto output_name : output_node_names_) { |
|
|
|
|
|
auto itr = outputs_map_.find(output_name); |
|
|
|
|
|
if (itr == outputs_map_.end()) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); |
|
|
|
|
|
REPORT_INNER_ERROR( "E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<std::string, int>> node_names_indexes = itr->second; |
|
|
|
|
|
for (const auto &node_name_index : node_names_indexes) { |
|
|
|
|
|
auto node_name = node_name_index.first; |
|
|
|
|
|
auto out_op_itr = name_operator_.find(node_name); |
|
|
|
|
|
if (out_op_itr == name_operator_.end()) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Get][Operator] Can not find operator: %s in graph.", node_name.c_str()); |
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Can not find operator: %s in graph.", node_name.c_str()); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
int index = node_name_index.second; |
|
|
|
|
|
output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)}); |
|
|
|
|
|
GELOGI("out node index %d, node:%s", index, node_name.c_str()); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -515,19 +680,146 @@ Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge:: |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { |
|
|
|
|
|
|
|
|
void OnnxModelParser::ClearMembers() { |
|
|
|
|
|
name_operator_.clear(); |
|
|
|
|
|
input_node_names_.clear(); |
|
|
|
|
|
output_node_names_.clear(); |
|
|
|
|
|
inputs_map_.clear(); |
|
|
|
|
|
outputs_map_.clear(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status OnnxModelParser::AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, |
|
|
|
|
|
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { |
|
|
|
|
|
std::queue<ge::onnx::GraphProto *> onnx_graph_tasks; |
|
|
|
|
|
int index = 0; |
|
|
|
|
|
onnx_graph_tasks.push(&root_onnx_graph); |
|
|
|
|
|
|
|
|
|
|
|
while (!onnx_graph_tasks.empty()) { |
|
|
|
|
|
ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); |
|
|
|
|
|
onnx_graph_tasks.pop(); |
|
|
|
|
|
for (int i = 0; i < onnx_graph->node_size(); i++) { |
|
|
|
|
|
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); |
|
|
|
|
|
if (node_proto->name().empty()) { |
|
|
|
|
|
std::string node_name = node_proto->op_type() + "_" + to_string(index++); |
|
|
|
|
|
node_proto->set_name(node_name); |
|
|
|
|
|
} |
|
|
|
|
|
GELOGD("adapt op name:%s, op type:%s", node_proto->name().c_str(), node_proto->op_type().c_str()); |
|
|
|
|
|
|
|
|
|
|
|
SubgraphAdapterFactory *factory = SubgraphAdapterFactory::Instance(); |
|
|
|
|
|
GE_CHECK_NOTNULL(factory); |
|
|
|
|
|
std::shared_ptr<SubgraphAdapter> subgraph_adapter = factory->CreateSubgraphAdapter(node_proto->op_type()); |
|
|
|
|
|
if(subgraph_adapter == nullptr) { |
|
|
|
|
|
GELOGD("Do not need adapt subgraph, op type:%s", node_proto->op_type().c_str()); |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<ge::onnx::GraphProto *> onnx_graphs; |
|
|
|
|
|
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph; |
|
|
|
|
|
if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) { |
|
|
|
|
|
GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str()); |
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str()); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for (const auto &onnx_graph : onnx_graphs) { |
|
|
|
|
|
onnx_graph_tasks.push(onnx_graph); |
|
|
|
|
|
} |
|
|
|
|
|
for (const auto &itr : name_to_onnx_subgraph) { |
|
|
|
|
|
name_to_onnx_graph.emplace(itr.first, itr.second); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph) { |
|
|
if (!onnx_model.has_graph()) { |
|
|
if (!onnx_model.has_graph()) { |
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E16004"); |
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E16004"); |
|
|
GELOGE(PARAM_INVALID, "Onnx model do not has graph."); |
|
|
GELOGE(PARAM_INVALID, "Onnx model do not has graph."); |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
ge::onnx::GraphProto onnx_graph = onnx_model.graph(); |
|
|
|
|
|
|
|
|
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_graph; |
|
|
|
|
|
std::deque<ParseArg> tasks; |
|
|
|
|
|
ge::onnx::GraphProto root_onnx_graph = onnx_model.graph(); |
|
|
|
|
|
|
|
|
|
|
|
auto ret = AdaptAndFindAllOnnxGraph(root_onnx_graph, name_to_onnx_graph); |
|
|
|
|
|
if (ret != SUCCESS) { |
|
|
|
|
|
GELOGE(FAILED, "[AdaptAndFind][OnnxGraph]adapt and find all onnx graph failed, root graph:%s.", |
|
|
|
|
|
root_onnx_graph.name().c_str()); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
auto opset_import = onnx_model.opset_import(); |
|
|
auto opset_import = onnx_model.opset_import(); |
|
|
for (auto it : opset_import) { |
|
|
for (auto it : opset_import) { |
|
|
domain_verseion_[it.domain()] = it.version(); |
|
|
domain_verseion_[it.domain()] = it.version(); |
|
|
GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); |
|
|
GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); |
|
|
} |
|
|
} |
|
|
|
|
|
std::string root_graph_name = root_graph.GetName().empty() ? "default_graph" : root_graph.GetName(); |
|
|
|
|
|
tasks.push_back({&root_onnx_graph, nullptr, root_graph_name, 0}); |
|
|
|
|
|
|
|
|
|
|
|
while (!tasks.empty()) { |
|
|
|
|
|
ParseArg arg = tasks.front(); |
|
|
|
|
|
tasks.pop_front(); |
|
|
|
|
|
bool is_subgraph = (arg.parent_node != nullptr) ? true : false; |
|
|
|
|
|
|
|
|
|
|
|
if (arg.onnx_graph == nullptr) { |
|
|
|
|
|
auto itr = name_to_onnx_graph.find(arg.graph_name); |
|
|
|
|
|
if (itr == name_to_onnx_graph.end()) { |
|
|
|
|
|
GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); |
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
arg.onnx_graph = itr->second; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ge::onnx::GraphProto *onnx_graph = arg.onnx_graph; |
|
|
|
|
|
ge::Graph tmp_graph(arg.graph_name); |
|
|
|
|
|
ret = ModelParseToGraphImpl(is_subgraph, *onnx_graph, tmp_graph); |
|
|
|
|
|
if (ret != SUCCESS) { |
|
|
|
|
|
GELOGE(ret, "[Parse][Model] Model parse to graph failed, graph name:%s.", arg.graph_name.c_str()); |
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "Model parse to graph failed, graph name:%s.", arg.graph_name.c_str()); |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
// To get the result for root graph |
|
|
|
|
|
if (!is_subgraph) { |
|
|
|
|
|
root_graph = tmp_graph; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ge::ComputeGraphPtr cur_compute_graph = ge::GraphUtils::GetComputeGraph(tmp_graph); |
|
|
|
|
|
GE_CHECK_NOTNULL(cur_compute_graph); |
|
|
|
|
|
|
|
|
|
|
|
ret = PostOpProcessForSubgraph(arg, cur_compute_graph); |
|
|
|
|
|
if (ret != SUCCESS) { |
|
|
|
|
|
GELOGE(ret, "[PostProcess][Subgraph]Post Op for subgraph:%s failed.", cur_compute_graph->GetName().c_str()); |
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Post Op for subgraph:%s failed.", cur_compute_graph->GetName().c_str()); |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ret = BuildLinkForChildAndParentGraph(cur_compute_graph, arg); |
|
|
|
|
|
if (ret != SUCCESS) { |
|
|
|
|
|
GELOGE(ret, "[BuildLink][Graph] Build link for child graph:%s and parent graph failed.", |
|
|
|
|
|
cur_compute_graph->GetName().c_str()); |
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Build link for child graph:%s and parent graph failed.", |
|
|
|
|
|
cur_compute_graph->GetName().c_str()); |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ret = GenSubgraphParseTasks(cur_compute_graph, tasks); |
|
|
|
|
|
if (ret != SUCCESS) { |
|
|
|
|
|
GELOGE(ret, "[Generate][Task] Failed to gen tasks on graph %s for next iteration", |
|
|
|
|
|
cur_compute_graph->GetName().c_str()); |
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Failed to gen tasks on graph %s for next iteration", |
|
|
|
|
|
cur_compute_graph->GetName().c_str()); |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
UpdateDataFormat(root_graph); |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { |
|
|
|
|
|
|
|
|
|
|
|
ClearMembers(); |
|
|
|
|
|
|
|
|
// 2. Get all inializer. |
|
|
// 2. Get all inializer. |
|
|
std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; |
|
|
std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; |
|
|
@@ -541,7 +833,8 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
|
|
|
|
|
|
// 3. Parse Input from graph. |
|
|
// 3. Parse Input from graph. |
|
|
GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); |
|
|
GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); |
|
|
Status ret = ParseInput(onnx_graph, initializer_name_tensor); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Status ret = ParseInput(initializer_name_tensor, is_subgraph, onnx_graph); |
|
|
if (ret != SUCCESS) { |
|
|
if (ret != SUCCESS) { |
|
|
GELOGE(ret, "Parse input for onnx failed."); |
|
|
GELOGE(ret, "Parse input for onnx failed."); |
|
|
return ret; |
|
|
return ret; |
|
|
@@ -555,6 +848,12 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ret = ParseOutput(onnx_graph); |
|
|
|
|
|
if (ret != SUCCESS) { |
|
|
|
|
|
GELOGE(ret, "[Parse][Output] Parse output for onnx failed."); |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// 5. Update node name for node do not has name. |
|
|
// 5. Update node name for node do not has name. |
|
|
ret = UpdateAllNodeName(onnx_graph); |
|
|
ret = UpdateAllNodeName(onnx_graph); |
|
|
if (ret != SUCCESS) { |
|
|
if (ret != SUCCESS) { |
|
|
@@ -582,6 +881,10 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<string> op_names; |
|
|
|
|
|
graph.GetAllOpName(op_names); |
|
|
|
|
|
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); |
|
|
|
|
|
|
|
|
// 8. Set all operator input. |
|
|
// 8. Set all operator input. |
|
|
ret = SetOperatorInputs(); |
|
|
ret = SetOperatorInputs(); |
|
|
if (ret != SUCCESS) { |
|
|
if (ret != SUCCESS) { |
|
|
@@ -589,22 +892,27 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<string> op_names; |
|
|
|
|
|
graph.GetAllOpName(op_names); |
|
|
|
|
|
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); |
|
|
|
|
|
|
|
|
|
|
|
// 9. Construct graph. |
|
|
// 9. Construct graph. |
|
|
std::vector<ge::Operator> input_ops; |
|
|
std::vector<ge::Operator> input_ops; |
|
|
|
|
|
|
|
|
ret = GetGraphInputs(input_ops); |
|
|
|
|
|
|
|
|
ret = GetGraphInputs(onnx_graph, input_ops); |
|
|
if (ret != SUCCESS) { |
|
|
if (ret != SUCCESS) { |
|
|
GELOGE(ret, "Get graph inputs failed."); |
|
|
GELOGE(ret, "Get graph inputs failed."); |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
graph.SetInputs(input_ops); |
|
|
graph.SetInputs(input_ops); |
|
|
|
|
|
|
|
|
|
|
|
// root graph needn't set outputs. |
|
|
|
|
|
if(is_subgraph) { |
|
|
|
|
|
std::vector<std::pair<Operator, std::vector<size_t>>> output_ops; |
|
|
|
|
|
ret = GetGraphOutputs(output_ops); |
|
|
|
|
|
if (ret != SUCCESS) { |
|
|
|
|
|
GELOGE(ret, "[Get][Outputs]Get graph outputs failed."); |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
graph.SetOutputs(output_ops); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); |
|
|
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); |
|
|
UpdateDataFormat(graph); |
|
|
|
|
|
|
|
|
|
|
|
GELOGI("Onnx model parser success."); |
|
|
GELOGI("Onnx model parser success."); |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
|