Browse Source

!1232 Fix batch norm bert fission for control depend case

Merge pull request !1232 from YuJianfeng/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
45d564d8f7
3 changed files with 28 additions and 7 deletions
  1. +14
    -3
      mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc
  2. +13
    -3
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc
  3. +1
    -1
      mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc

+ 14
- 3
mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc View File

@@ -25,6 +25,7 @@ namespace opt {
namespace {
const std::vector<int> kOutputIndex{0, 3, 4, 5};
constexpr size_t kBatchNormRealOutputNum = 3;
constexpr size_t kBatchNormRealInputNum = 3;

bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
MS_EXCEPTION_IF_NULL(n1);
@@ -56,6 +57,9 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
for (const auto &node_index : manager->node_users()[bn]) {
AnfNodePtr output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
continue;
}
auto tuple_getiterm_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode);
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem);
@@ -77,7 +81,10 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
MS_EXCEPTION_IF_NULL(bn);
auto bn_cnode = bn->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(bn_cnode);
CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1);
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
<< kBatchNormRealInputNum + 1;
}
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
@@ -100,7 +107,10 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
MS_EXCEPTION_IF_NULL(bn);
auto bn_cnode = bn->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(bn_cnode);
CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1);
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
<< kBatchNormRealInputNum + 1;
}
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
<< ", but it is " << bn_training_reduce_outputs.size();
@@ -164,7 +174,8 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c
(void)manager->Replace(output, bn_training_update_v2_outputs[output_index]);
output_index++;
}
return nullptr;
// Return the new node for control depends.
return bn_training_update_v2;
}
} // namespace opt
} // namespace mindspore

+ 13
- 3
mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc View File

@@ -90,9 +90,19 @@ ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const {
}
auto tensor_ptr = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor_ptr);
auto *tensor_data = static_cast<float *>(tensor_ptr->data_c());
MS_EXCEPTION_IF_NULL(tensor_data);
return MakeValue(tensor_data[0]);
if (tensor_ptr->data_type() == kNumberTypeFloat16) {
auto *half_data = static_cast<const Eigen::half *>(tensor_ptr->data_c());
MS_EXCEPTION_IF_NULL(half_data);
float float_data = Eigen::half_impl::half_to_float(half_data[0]);
return MakeValue(float_data);
} else if (tensor_ptr->data_type() == kNumberTypeFloat32) {
auto *tensor_data = static_cast<const float *>(tensor_ptr->data_c());
MS_EXCEPTION_IF_NULL(tensor_data);
return MakeValue(tensor_data[0]);
} else {
MS_LOG(WARNING) << "The factor data type of value node " << value_node->DebugString() << " is not fp16 or fp32";
return nullptr;
}
}

AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,


+ 1
- 1
mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc View File

@@ -65,7 +65,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimTupleGetItem->name()) {
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
return nullptr;
}
if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(),


Loading…
Cancel
Save