From 8a7b56820363c577eb66f4d4f17078df8db3ac71 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Sat, 10 Apr 2021 18:21:11 +0800 Subject: [PATCH] add relu fusion check --- .../ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc | 6 ++++++ mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc index 2308ae6dab..147505b0e6 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc @@ -73,6 +73,12 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A return nullptr; } + std::vector shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0); + std::vector shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1); + if (shape1 != shape2) { + return nullptr; + } + auto prim = std::make_shared(kFusedAddReluGradV2Name); MS_EXCEPTION_IF_NULL(prim); std::vector inputs = {NewValueNode(prim), x1, x2, mask}; diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc index 1cea7134c4..d9a7d4f595 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc @@ -72,6 +72,12 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo return nullptr; } + std::vector shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0); + std::vector shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1); + if (shape1 != shape2) { + return nullptr; + } + auto prim = std::make_shared(kFusedAddReluV2Name); MS_EXCEPTION_IF_NULL(prim); std::vector inputs = {NewValueNode(prim), x1, x2};