Browse Source

!8867 [bugfix] fused batch norm op's input channel count should be a multiple of 4

From: @zyli2020
Reviewed-by: @limingqi107,@cristoval
Signed-off-by: @cristoval
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
6d7e934a52
5 changed files with 17 additions and 0 deletions
  1. +4
    -0
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc
  2. +4
    -0
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
  3. +4
    -0
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc
  4. +4
    -0
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc
  5. +1
    -0
      mindspore/ccsrc/utils/utils.h

+ 4
- 0
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc View File

@@ -52,6 +52,10 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
return nullptr;
}
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);


+ 4
- 0
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc View File

@@ -120,6 +120,10 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return false;
}
auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
return false;
}
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);


+ 4
- 0
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc View File

@@ -49,6 +49,10 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
return nullptr;
}
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);


+ 4
- 0
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc View File

@@ -44,6 +44,10 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}
auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
return nullptr;
}
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);


+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -344,6 +344,7 @@ const size_t kShape5dDims = 5;
const size_t kShape1dDims = 1;
const size_t kCubeSize = 16;
const size_t kMemAlignSize = 512;
const size_t kBNChannelMultipleFactor = 4;
const int kParameterDataTensorMask = 0;
const int kParameterWeightTensorMask = 1;
const int kValueNodeTensorMask = 2;


Loading…
Cancel
Save