From 432192b1d85ed326d3ece33721a99ab8437d6295 Mon Sep 17 00:00:00 2001 From: wangcong Date: Tue, 2 Jun 2020 16:40:01 +0800 Subject: [PATCH] match ub fusion pass --- .../buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc | 6 ++---- .../ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc | 6 ++---- .../buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc | 7 ++----- .../ascend/buffer_fusion/eltwise_fusion_pass.cc | 1 + .../ascend/buffer_fusion/segment_eltwise_fusion_pass.cc | 1 + 5 files changed, 8 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc index 8c4b1dcc63..118d566383 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc @@ -70,10 +70,8 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { auto eltwise_input = cnode->input(1); - if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) { - if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { - MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); - } + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { + MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); } } } diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc index 348504345a..6bb2a00087 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc @@ -65,10 +65,8 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { auto eltwise_input = cnode->input(1); - if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) { - if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { - MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); - } + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { + MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); } } } diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc index f485e901d8..98a6838bed 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc @@ -74,11 +74,8 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { auto eltwise_input = cnode->input(1); - if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) { - if (eltwise_input->isa() && - AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) { - MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); - } + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) { + MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); } } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc index 42860de700..ca7bd4dc3a 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc @@ -55,6 +55,7 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(candidate_fusion); std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); for (auto &node : node_list) { if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc index 1926d64c61..0144c90a86 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc @@ -73,6 +73,7 @@ void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGra FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(candidate_fusion); std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); for (auto &node : node_list) { if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {