Browse Source

!1996 Fix single BatchNorm fission && SoftmaxGradExt fusion pass

Merge pull request !1996 from huanghui/single-batchnorm-fission-pass
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
3b809f2b0b
7 changed files with 31 additions and 9 deletions
  1. +1
    -1
      mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
  2. +10
    -6
      mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc
  3. +4
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc
  4. +11
    -0
      mindspore/ccsrc/pre_activate/common/helper.cc
  5. +3
    -0
      mindspore/ccsrc/pre_activate/common/helper.h
  6. +1
    -1
      tests/ut/cpp/python_input/gtest_input/pre_activate/single_batch_norm_fission_test.py
  7. +1
    -1
      tests/ut/cpp/python_input/gtest_input/pre_activate/softmax_grad_ext_fusion.py

+ 1
- 1
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc View File

@@ -128,7 +128,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_

std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
std::string convert_format;
size_t counter = 0;
const size_t counter = 0;
for (const auto &iter : format_counter) {
if (counter < iter.second) {
convert_format = iter.first;


+ 10
- 6
mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc View File

@@ -129,18 +129,22 @@ const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
std::vector<AnfNodePtr> bn_outputs;
if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() < kBatchNormRealInputNum + 1) {
if (cnode->size() < kBatchNormRealInputNum + 1) {
MS_LOG(INFO) << "The input num of BatchNorm less than" << kBatchNormRealInputNum
<< ". The node should not be changed";
return nullptr;
}
if (!GetBoolAttr(cnode, kAttrIsTraining)) {
MS_LOG(INFO) << "is training should be true if do fusion";
return nullptr;
}
std::vector<AnfNodePtr> bn_outputs;
if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
return nullptr;
}
AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
std::vector<AnfNodePtr> bn_training_reduce_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,


+ 4
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc View File

@@ -58,6 +58,10 @@ const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const
auto input1 = GetAnfNodeByVar(equiv, input1_);
auto input2 = GetAnfNodeByVar(equiv, input2_);
auto sum = GetAnfNodeByVar(equiv, sum_var_);
if (!GetBoolAttr(sum, kAttrKeepDims)) {
MS_LOG(INFO) << "sum's attr keep_dims should be true if do fusion";
return nullptr;
}

auto prim = std::make_shared<Primitive>(kSoftmaxGradExtOpName);
auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2});


+ 11
- 0
mindspore/ccsrc/pre_activate/common/helper.cc View File

@@ -722,5 +722,16 @@ bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
MS_EXCEPTION_IF_NULL(value_node2);
return GetValue<int>(value_node1->value()) < GetValue<int>(value_node2->value());
}

// Read a boolean attribute from a CNode.
// Returns true only when `node` is a CNode that carries attr `attr_name`
// and that attribute's value is true; returns false in every other case
// (non-CNode input, or attribute absent).
bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
  MS_EXCEPTION_IF_NULL(node);
  // Only CNodes carry attributes; anything else cannot satisfy the query.
  if (!node->isa<CNode>()) {
    MS_LOG(INFO) << "node is not a cnode";
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Probe for the attribute first so the typed read below is only
  // attempted when the attribute actually exists.
  if (!AnfAlgo::HasNodeAttr(attr_name, cnode)) {
    return false;
  }
  return AnfAlgo::GetNodeAttr<bool>(node, attr_name);
}
} // namespace opt
} // namespace mindspore

+ 3
- 0
mindspore/ccsrc/pre_activate/common/helper.h View File

@@ -180,6 +180,9 @@ AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node);

// Compare tuple getitem's index, return bool[n1's index < n2's index]
bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2);

// Get attr which is bool from cnode
bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_

+ 1
- 1
tests/ut/cpp/python_input/gtest_input/pre_activate/single_batch_norm_fission_test.py View File

@@ -17,7 +17,7 @@ from mindspore.ops import operations as P

make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
BatchNorm = P.BatchNorm()
BatchNorm = P.BatchNorm(is_training=True)
BNTrainingReduce = Primitive('BNTrainingReduce')
BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')



+ 1
- 1
tests/ut/cpp/python_input/gtest_input/pre_activate/softmax_grad_ext_fusion.py View File

@@ -16,7 +16,7 @@ from mindspore.ops import Primitive
from mindspore.ops import operations as P

Mul = P.Mul()
ReduceSum = P.ReduceSum()
ReduceSum = P.ReduceSum(keep_dims=True)
Sub = P.Sub()
SoftmaxGradExt = Primitive('SoftmaxGradExt')
MakeTuple = Primitive('make_tuple')


Loading…
Cancel
Save