|
|
|
@@ -27,15 +27,15 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, |
|
|
|
const session::KernelGraph &kernel_graph, |
|
|
|
FusedNodeRecord *candidate_fusion) { |
|
|
|
void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input, |
|
|
|
const session::KernelGraph &kernel_graph, |
|
|
|
FusedNodeRecord *candidate_fusion) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(candidate_fusion); |
|
|
|
auto manager = kernel_graph.manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
MS_EXCEPTION_IF_NULL(relu_input); |
|
|
|
auto getitem = relu_input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(eltwise_input); |
|
|
|
auto getitem = eltwise_input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(getitem); |
|
|
|
auto bnupdate = getitem->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(bnupdate); |
|
|
|
@@ -68,10 +68,11 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && |
|
|
|
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { |
|
|
|
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && |
|
|
|
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) { |
|
|
|
auto eltwise_input = cnode->input(1); |
|
|
|
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { |
|
|
|
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); |
|
|
|
MatchBnupdateDoubleOutputEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|