Browse Source

!14913 add relu fusion check

From: @wilfchen
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
pull/14913/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
888a2565bc
2 changed files with 12 additions and 0 deletions
  1. +6
    -0
      mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc
  2. +6
    -0
      mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc

+ 6
- 0
mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc View File

@@ -73,6 +73,12 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A
return nullptr;
}

std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0);
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1);
if (shape1 != shape2) {
return nullptr;
}

auto prim = std::make_shared<Primitive>(kFusedAddReluGradV2Name);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2, mask};


+ 6
- 0
mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc View File

@@ -72,6 +72,12 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo
return nullptr;
}

std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0);
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1);
if (shape1 != shape2) {
return nullptr;
}

auto prim = std::make_shared<Primitive>(kFusedAddReluV2Name);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2};


Loading…
Cancel
Save