| @@ -27,6 +27,20 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| constexpr size_t kEltwiseInputSize = 2; | |||||
| constexpr size_t kEltwiseOutputSize = 2; | |||||
| bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) { | |||||
| if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) { | |||||
| return true; | |||||
| } | |||||
| if (AnfAlgo::GetOutputTensorNum(node) == kEltwiseOutputSize) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace | |||||
| void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, | void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, | ||||
| const session::KernelGraph &kernel_graph, | const session::KernelGraph &kernel_graph, | ||||
| FusedNodeRecord *candidate_fusion) { | FusedNodeRecord *candidate_fusion) { | ||||
| @@ -74,8 +88,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && | if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && | ||||
| AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { | |||||
| AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) { | |||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | |||||
| if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { | if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { | ||||
| MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); | MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); | ||||
| } | } | ||||