|
|
|
@@ -74,11 +74,8 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker |
|
|
|
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && |
|
|
|
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { |
|
|
|
auto eltwise_input = cnode->input(1); |
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) { |
|
|
|
if (eltwise_input->isa<CNode>() && |
|
|
|
AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) { |
|
|
|
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); |
|
|
|
} |
|
|
|
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) { |
|
|
|
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); |
|
|
|
} |
|
|
|
} else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { |
|
|
|
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); |
|
|
|
|