Browse Source

fix conv2dbackprop_eltwise_eltwise_fusion_pass which lack output_used_num

pull/15224/head
yuchaojie 4 years ago
parent
commit
e745a90b9a
1 changed files with 10 additions and 0 deletions
  1. +10
    -0
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc

+ 10
- 0
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc View File

@@ -38,10 +38,13 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
} else { } else {
return; return;
} }
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
auto input_cnode = eltwise_input->cast<CNodePtr>(); auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode); MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(2); auto double_in_eltwise_input = input_cnode->input(2);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input); MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
std::vector<int64_t> conv2d_bp_output_used_num;
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) { if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) {
return; return;
} }
@@ -50,6 +53,8 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
(void)record.insert(double_in_eltwise_input); (void)record.insert(double_in_eltwise_input);
candidate_fusion->push_back(record); candidate_fusion->push_back(record);
SetRecordFusionId(record); SetRecordFusionId(record);
conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input].size()));
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input);
} else { } else {
auto double_in_eltwise_input_1 = input_cnode->input(1); auto double_in_eltwise_input_1 = input_cnode->input(1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1); MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1);
@@ -61,8 +66,13 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
(void)record.insert(double_in_eltwise_input_1); (void)record.insert(double_in_eltwise_input_1);
candidate_fusion->push_back(record); candidate_fusion->push_back(record);
SetRecordFusionId(record); SetRecordFusionId(record);
conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input_1].size()));
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input_1);
} }
} }
std::vector<int64_t> eltwise_output_used_num;
eltwise_output_used_num.emplace_back(SizeToLong(manager->node_users()[input_cnode].size()));
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(eltwise_output_used_num), eltwise_input);
} }


void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,


Loading…
Cancel
Save