diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc index 466aeb39e6..0e231576cc 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc @@ -52,6 +52,10 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { return nullptr; } + auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0); + if (shape.back() % kBNChannelMultipleFactor != 0) { + return nullptr; + } auto x = AnfAlgo::GetInputNode(utils::cast(batch_norm_ex), 0); auto scale = AnfAlgo::GetInputNode(utils::cast(batch_norm_ex), 1); diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc index d80f865260..3745b28006 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc @@ -120,6 +120,10 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) { if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { return false; } + auto shape = AnfAlgo::GetInputDeviceShape(node, 0); + if (shape.back() % kBNChannelMultipleFactor != 0) { + return false; + } auto relu_grad = AnfAlgo::GetInputNode(utils::cast(node), 0); MS_EXCEPTION_IF_NULL(relu_grad); diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc index 629dd17714..92faf0f325 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc @@ -49,6 +49,10 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { return nullptr; } + auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0); + if (shape.back() % kBNChannelMultipleFactor != 0) { + return nullptr; + } auto x = AnfAlgo::GetInputNode(utils::cast(batch_norm_ex), 0); auto scale = AnfAlgo::GetInputNode(utils::cast(batch_norm_ex), 1); diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc index e8dc539591..eeff8830da 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc @@ -44,6 +44,10 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { return nullptr; } + auto shape = AnfAlgo::GetInputDeviceShape(node, 0); + if (shape.back() % kBNChannelMultipleFactor != 0) { + return nullptr; + } auto relu_grad = AnfAlgo::GetInputNode(utils::cast(node), 0); MS_EXCEPTION_IF_NULL(relu_grad); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 8cb1170ae6..808ea22373 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -344,6 +344,7 @@ const size_t kShape5dDims = 5; const size_t kShape1dDims = 1; const size_t kCubeSize = 16; const size_t kMemAlignSize = 512; +const size_t kBNChannelMultipleFactor = 4; const int kParameterDataTensorMask = 0; const int kParameterWeightTensorMask = 1; const int kValueNodeTensorMask = 2;