Browse Source

!1786 match fusion type

Merge pull request !1786 from wangcong/master
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
5cba231ba9
5 changed files with 8 additions and 13 deletions
  1. +2
    -4
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc
  2. +2
    -4
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc
  3. +2
    -5
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc
  4. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc
  5. +1
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc

+ 2
- 4
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc View File

@@ -70,10 +70,8 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
auto eltwise_input = cnode->input(1);
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
}
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
}
}
}


+ 2
- 4
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc View File

@@ -65,10 +65,8 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
auto eltwise_input = cnode->input(1);
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
}
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
}
}
}


+ 2
- 5
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc View File

@@ -74,11 +74,8 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
auto eltwise_input = cnode->input(1);
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
if (eltwise_input->isa<CNode>() &&
AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
}
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
}
} else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);


+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc View File

@@ -55,6 +55,7 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {


+ 1
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc View File

@@ -73,6 +73,7 @@ void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGra
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {


Loading…
Cancel
Save