From: @zyli2020 Reviewed-by: @limingqi107,@cristoval Signed-off-by: @cristovaltags/v1.1.0
| @@ -52,6 +52,10 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons | |||||
| if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { | if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0); | |||||
| if (shape.back() % kBNChannelMultipleFactor != 0) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0); | auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0); | ||||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1); | auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1); | ||||
| @@ -120,6 +120,10 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||||
| if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { | if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto shape = AnfAlgo::GetInputDeviceShape(node, 0); | |||||
| if (shape.back() % kBNChannelMultipleFactor != 0) { | |||||
| return false; | |||||
| } | |||||
| auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); | auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); | ||||
| MS_EXCEPTION_IF_NULL(relu_grad); | MS_EXCEPTION_IF_NULL(relu_grad); | ||||
| @@ -49,6 +49,10 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A | |||||
| if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { | if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0); | |||||
| if (shape.back() % kBNChannelMultipleFactor != 0) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0); | auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0); | ||||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1); | auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1); | ||||
| @@ -44,6 +44,10 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con | |||||
| if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { | if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto shape = AnfAlgo::GetInputDeviceShape(node, 0); | |||||
| if (shape.back() % kBNChannelMultipleFactor != 0) { | |||||
| return nullptr; | |||||
| } | |||||
| auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); | auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); | ||||
| MS_EXCEPTION_IF_NULL(relu_grad); | MS_EXCEPTION_IF_NULL(relu_grad); | ||||
| @@ -344,6 +344,7 @@ const size_t kShape5dDims = 5; | |||||
| const size_t kShape1dDims = 1; | const size_t kShape1dDims = 1; | ||||
| const size_t kCubeSize = 16; | const size_t kCubeSize = 16; | ||||
| const size_t kMemAlignSize = 512; | const size_t kMemAlignSize = 512; | ||||
| const size_t kBNChannelMultipleFactor = 4; | |||||
| const int kParameterDataTensorMask = 0; | const int kParameterDataTensorMask = 0; | ||||
| const int kParameterWeightTensorMask = 1; | const int kParameterWeightTensorMask = 1; | ||||
| const int kValueNodeTensorMask = 2; | const int kValueNodeTensorMask = 2; | ||||