diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc index 0b023c3691..be8aa21854 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc @@ -25,6 +25,7 @@ namespace opt { namespace { const std::vector kOutputIndex{0, 3, 4, 5}; constexpr size_t kBatchNormRealOutputNum = 3; +constexpr size_t kBatchNormRealInputNum = 3; bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) { MS_EXCEPTION_IF_NULL(n1); @@ -56,6 +57,9 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s for (const auto &node_index : manager->node_users()[bn]) { AnfNodePtr output = node_index.first; MS_EXCEPTION_IF_NULL(output); + if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { + continue; + } auto tuple_getiterm_cnode = output->cast(); MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); @@ -77,7 +81,10 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP MS_EXCEPTION_IF_NULL(bn); auto bn_cnode = bn->cast(); MS_EXCEPTION_IF_NULL(bn_cnode); - CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1); + if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { + MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " + << kBatchNormRealInputNum + 1; + } std::vector bn_training_reduce_inputs = { NewValueNode(std::make_shared(kBNTrainingReduceOpName)), bn_cnode->input(1)}; auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); @@ -100,7 +107,10 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod MS_EXCEPTION_IF_NULL(bn); auto bn_cnode = bn->cast(); MS_EXCEPTION_IF_NULL(bn_cnode); - CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1); + if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { + MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " + << kBatchNormRealInputNum + 1; + } if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum << ", but it is " << bn_training_reduce_outputs.size(); @@ -164,7 +174,8 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); output_index++; } - return nullptr; + // Return the new node for control depends. + return bn_training_update_v2; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc index de8a7d9d55..8e48780936 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc @@ -90,9 +90,19 @@ ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { } auto tensor_ptr = value->cast(); MS_EXCEPTION_IF_NULL(tensor_ptr); - auto *tensor_data = static_cast(tensor_ptr->data_c()); - MS_EXCEPTION_IF_NULL(tensor_data); - return MakeValue(tensor_data[0]); + if (tensor_ptr->data_type() == kNumberTypeFloat16) { + auto *half_data = static_cast(tensor_ptr->data_c()); + MS_EXCEPTION_IF_NULL(half_data); + float float_data = Eigen::half_impl::half_to_float(half_data[0]); + return MakeValue(float_data); + } else if (tensor_ptr->data_type() == kNumberTypeFloat32) { + auto *tensor_data = static_cast(tensor_ptr->data_c()); + MS_EXCEPTION_IF_NULL(tensor_data); + return MakeValue(tensor_data[0]); + } else { + MS_LOG(WARNING) << "The factor data type of value node " << value_node->DebugString() << " is not fp16 or fp32"; + return nullptr; + } } AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc index 93c1b73038..66b3dc1d88 100644 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc @@ -65,7 +65,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimTupleGetItem->name()) { + if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { return nullptr; } if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(),