|
|
|
@@ -546,7 +546,7 @@ void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr |
|
|
|
} |
|
|
|
|
|
|
|
void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, |
|
|
|
FusedNodeRecord *candidate_fusion, bool is_order) { |
|
|
|
FusedNodeRecord *candidate_fusion, bool is_order) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(candidate_fusion); |
|
|
|
auto manager = kernel_graph.manager(); |
|
|
|
@@ -595,7 +595,7 @@ void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, |
|
|
|
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion); |
|
|
|
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) { |
|
|
|
MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion); |
|
|
|
} else if (relu_input->isa<CNode>() && |
|
|
|
} else if (relu_input->isa<CNode>() && |
|
|
|
AnfAlgo::GetCNodeName(relu_input) == prim::kPrimDepthwiseConv2dNative->name()) { |
|
|
|
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); |
|
|
|
} |
|
|
|
|