From a7c992853124ca776fdabd5ffcfb62371cc6342f Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Sat, 12 Dec 2020 17:10:20 +0800 Subject: [PATCH] modified: ge/graph/passes/compile_nodes_pass.cc modified: ge/graph/passes/folding_pass.cc modified: ge/graph/passes/unused_const_pass.cc modified: inc/framework/common/debug/log.h --- ge/graph/passes/compile_nodes_pass.cc | 14 +++++++------- ge/graph/passes/folding_pass.cc | 17 +++++++++++------ ge/graph/passes/unused_const_pass.cc | 7 ++++++- inc/framework/common/debug/log.h | 4 ++-- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/ge/graph/passes/compile_nodes_pass.cc b/ge/graph/passes/compile_nodes_pass.cc index 1ed9caf0..6b0b9305 100755 --- a/ge/graph/passes/compile_nodes_pass.cc +++ b/ge/graph/passes/compile_nodes_pass.cc @@ -70,8 +70,8 @@ graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { kernel_to_compile_nodes.insert(std::make_pair(kernel_lib_name, node_vec)); } } else { - GELOGE(GRAPH_FAILED, "Get node:%s, type:%s supported kernel failed.", node->GetName().c_str(), - node->GetType().c_str()); + GE_ERRORLOG_AND_ERRORMSG(GRAPH_FAILED, "Get node:%s, type:%s supported kernel failed.", node->GetName().c_str(), + node->GetType().c_str()); return GRAPH_FAILED; } } @@ -99,8 +99,8 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: (void)instance->DNNEngineManagerObj().GetDNNEngineName(node); kernel_lib_name = op_desc->GetOpKernelLibName(); if (kernel_lib_name.empty()) { - GELOGE(GRAPH_FAILED, "Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), - op_desc->GetType().c_str()); + GE_ERRORLOG_AND_ERRORMSG(GRAPH_FAILED, "Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), + op_desc->GetType().c_str()); return GRAPH_FAILED; } } @@ -130,8 +130,8 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: return GRAPH_SUCCESS; } } - GELOGE(GRAPH_FAILED, "Cannot find kernel lib support node:%s, type:%s , get kernel lib failed.", - node->GetName().c_str(), op_desc->GetType().c_str()); + GE_ERRORLOG_AND_ERRORMSG(GRAPH_FAILED, "Cannot find kernel lib support node:%s, type:%s , get kernel lib failed.", + node->GetName().c_str(), op_desc->GetType().c_str()); return GRAPH_FAILED; } return GRAPH_SUCCESS; @@ -173,7 +173,7 @@ graphStatus CompileNodesPass::CompileNodes(const std::shared_ptr instance } auto ret = kernel_info->CompileOp(kernel_nodes.second); if (ret != GRAPH_SUCCESS) { - GELOGE(ret, "Compile op failed, kernel name is %s", kernel_nodes.first.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ret, "Compile op failed, kernel name is %s", kernel_nodes.first.c_str()); return GRAPH_FAILED; } } diff --git a/ge/graph/passes/folding_pass.cc b/ge/graph/passes/folding_pass.cc index 93dc2c40..b1bd5a61 100755 --- a/ge/graph/passes/folding_pass.cc +++ b/ge/graph/passes/folding_pass.cc @@ -173,10 +173,7 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { continue; } auto in_node = in_node_anchor->GetOwnerNode(); - if (in_node == nullptr) { - continue; - } - if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) { + if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) { GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); auto ret = in_node_anchor->Unlink(in_data_anchor); if (ret != SUCCESS) { @@ -188,7 +185,7 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { node->GetName().c_str()); auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx()); auto identity = - AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph); + AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph); if (identity == nullptr) { GELOGE(INTERNAL_ERROR, "Failed to add identity node to graph."); return INTERNAL_ERROR; @@ -241,6 +238,14 @@ Status FoldingPass::AddConstNode(NodePtr &node, IndexsToAnchors indexes_to_ancho node->GetName().c_str(), index); return INTERNAL_ERROR; } + + vector curr_origin_op_names; + (void)AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, curr_origin_op_names); + if (curr_origin_op_names.empty()) { + (void)AttrUtils::SetListStr(const_node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, {node->GetName()}); + } else { + (void)AttrUtils::SetListStr(const_node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, curr_origin_op_names); + } GELOGI("add const_node:%s, replace node %s, type %s, index %zu.", const_node->GetName().c_str(), node->GetName().c_str(), node->GetType().c_str(), index); // add new const to re-pass node @@ -328,4 +333,4 @@ Status FoldingPass::ConnectNodeToInAnchor(InDataAnchorPtr &in_anchor, NodePtr &n AddRePassNodesWithInOut(node); return SUCCESS; } -} // namespace ge +} // namespace ge \ No newline at end of file diff --git a/ge/graph/passes/unused_const_pass.cc b/ge/graph/passes/unused_const_pass.cc index 7c57c53e..8f94fc85 100644 --- a/ge/graph/passes/unused_const_pass.cc +++ b/ge/graph/passes/unused_const_pass.cc @@ -40,6 +40,11 @@ Status UnusedConstPass::Run(NodePtr &node) { GELOGD("op type is unused const."); return IsolateAndDeleteNode(node, {-1}); } + // remove those const which only has control-in and control-out + if ((op_type == CONSTANT || op_type == CONSTANTOP) && (node->GetOutDataNodesSize() == 0)) { + GELOGD("Remove unused const %s.", node->GetName().c_str()); + return IsolateAndDeleteNode(node, {-1}); + } return SUCCESS; } -} // namespace ge +} // namespace ge \ No newline at end of file diff --git a/inc/framework/common/debug/log.h b/inc/framework/common/debug/log.h index 249271a6..b55cc28c 100644 --- a/inc/framework/common/debug/log.h +++ b/inc/framework/common/debug/log.h @@ -258,7 +258,7 @@ #define GE_ERRORLOG_AND_ERRORMSG(_status, errormsg) \ { \ GELOGE(_status, "%s", errormsg); \ - ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); \ + ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {errormsg}); \ } #define GE_CHK_LOG_AND_ERRORMSG(expr, _status, errormsg) \ @@ -266,7 +266,7 @@ bool b = (expr); \ if (!b) { \ GELOGE(_status, "%s", errormsg); \ - ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); \ + ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {errormsg}); \ return _status; \ } \ } while (0)