| @@ -38,10 +38,13 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw | |||||
| } else { | } else { | ||||
| return; | return; | ||||
| } | } | ||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto input_cnode = eltwise_input->cast<CNodePtr>(); | auto input_cnode = eltwise_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_cnode); | MS_EXCEPTION_IF_NULL(input_cnode); | ||||
| auto double_in_eltwise_input = input_cnode->input(2); | auto double_in_eltwise_input = input_cnode->input(2); | ||||
| MS_EXCEPTION_IF_NULL(double_in_eltwise_input); | MS_EXCEPTION_IF_NULL(double_in_eltwise_input); | ||||
| std::vector<int64_t> conv2d_bp_output_used_num; | |||||
| if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) { | if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -50,6 +53,8 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw | |||||
| (void)record.insert(double_in_eltwise_input); | (void)record.insert(double_in_eltwise_input); | ||||
| candidate_fusion->push_back(record); | candidate_fusion->push_back(record); | ||||
| SetRecordFusionId(record); | SetRecordFusionId(record); | ||||
| conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input].size())); | |||||
| AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input); | |||||
| } else { | } else { | ||||
| auto double_in_eltwise_input_1 = input_cnode->input(1); | auto double_in_eltwise_input_1 = input_cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1); | MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1); | ||||
| @@ -61,8 +66,13 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw | |||||
| (void)record.insert(double_in_eltwise_input_1); | (void)record.insert(double_in_eltwise_input_1); | ||||
| candidate_fusion->push_back(record); | candidate_fusion->push_back(record); | ||||
| SetRecordFusionId(record); | SetRecordFusionId(record); | ||||
| conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input_1].size())); | |||||
| AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input_1); | |||||
| } | } | ||||
| } | } | ||||
| std::vector<int64_t> eltwise_output_used_num; | |||||
| eltwise_output_used_num.emplace_back(SizeToLong(manager->node_users()[input_cnode].size())); | |||||
| AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(eltwise_output_used_num), eltwise_input); | |||||
| } | } | ||||
| void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, | void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, | ||||