Browse Source

Fix bnupdate_eltwise_eltwise ub fusion pass

tags/v1.0.0
yujianfeng 5 years ago
parent
commit
c4bbf5a282
1 changed files with 16 additions and 1 deletions
  1. +16
    -1
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc

+ 16
- 1
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc View File

@@ -27,6 +27,20 @@

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kEltwiseInputSize = 2;
constexpr size_t kEltwiseOutputSize = 2;
bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) {
if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) {
return true;
}
if (AnfAlgo::GetOutputTensorNum(node) == kEltwiseOutputSize) {
return true;
}
return false;
}
} // namespace

void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
@@ -74,8 +88,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) {
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
}


Loading…
Cancel
Save