Browse Source

modify the condition of pattern match in bnupdate + eltwise fusion pass

tags/v0.7.0-beta
etone-chan 5 years ago
parent
commit
18c83637f1
3 changed files with 11 additions and 9 deletions
  1. +8
    -7
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc
  2. +2
    -2
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h
  3. +1
    -0
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h

+ 8
- 7
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc View File

@@ -27,15 +27,15 @@

namespace mindspore {
namespace opt {
void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(relu_input);
auto getitem = relu_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(eltwise_input);
auto getitem = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto bnupdate = getitem->input(1);
MS_EXCEPTION_IF_NULL(bnupdate);
@@ -68,10 +68,11 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) {
auto eltwise_input = cnode->input(1);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
MatchBnupdateDoubleOutputEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
}
}
}


+ 2
- 2
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h View File

@@ -39,8 +39,8 @@ class BnupdateEltwiseFusionPass : public FusionBasePass {
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

private:
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
void MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
};
} // namespace opt
} // namespace mindspore


+ 1
- 0
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h View File

@@ -33,6 +33,7 @@ const int8_t MAX_ELTWISE_NUM = 3;
const int8_t MIN_ELTWISE_SIZE = 2;
const int8_t ELTWISE_INPUT_SIZE = 2;
const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3;
const int8_t ELTWISE_DOUBLE_OUTPUT_SIZE = 2;
const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3;
const int8_t CONV_QUART_IN_INPUT_SIZE = 5;
const int8_t ELTWISE_USE = 1;


Loading…
Cancel
Save