|
|
|
@@ -27,6 +27,20 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
constexpr size_t kEltwiseInputSize = 2; |
|
|
|
constexpr size_t kEltwiseOutputSize = 2; |
|
|
|
bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) { |
|
|
|
if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (AnfAlgo::GetOutputTensorNum(node) == kEltwiseOutputSize) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, |
|
|
|
const session::KernelGraph &kernel_graph, |
|
|
|
FusedNodeRecord *candidate_fusion) { |
|
|
|
@@ -74,8 +88,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && |
|
|
|
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { |
|
|
|
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) { |
|
|
|
auto eltwise_input = cnode->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(eltwise_input); |
|
|
|
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { |
|
|
|
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); |
|
|
|
} |
|
|
|
|