|
|
|
@@ -38,10 +38,13 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw |
|
|
|
} else { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto manager = kernel_graph.manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto input_cnode = eltwise_input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(input_cnode); |
|
|
|
auto double_in_eltwise_input = input_cnode->input(2); |
|
|
|
MS_EXCEPTION_IF_NULL(double_in_eltwise_input); |
|
|
|
std::vector<int64_t> conv2d_bp_output_used_num; |
|
|
|
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -50,6 +53,8 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw |
|
|
|
(void)record.insert(double_in_eltwise_input); |
|
|
|
candidate_fusion->push_back(record); |
|
|
|
SetRecordFusionId(record); |
|
|
|
conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input].size())); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input); |
|
|
|
} else { |
|
|
|
auto double_in_eltwise_input_1 = input_cnode->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1); |
|
|
|
@@ -61,8 +66,13 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw |
|
|
|
(void)record.insert(double_in_eltwise_input_1); |
|
|
|
candidate_fusion->push_back(record); |
|
|
|
SetRecordFusionId(record); |
|
|
|
conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input_1].size())); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input_1); |
|
|
|
} |
|
|
|
} |
|
|
|
std::vector<int64_t> eltwise_output_used_num; |
|
|
|
eltwise_output_used_num.emplace_back(SizeToLong(manager->node_users()[input_cnode].size())); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(eltwise_output_used_num), eltwise_input); |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, |
|
|
|
|