diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc
index d03e29b6c4..245186ea47 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc
@@ -182,18 +182,34 @@ AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const
   return device_num_reciprocal_value;
 }
 
+AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const TypeId dst_type) {
+  if (AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
+    AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
+    AnfAlgo::SetOutputInferTypeAndShape({dst_type}, {AnfAlgo::GetOutputInferShape(input, 0)}, cast.get());
+    AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
+    cast->set_scope(input->scope());
+    return cast;
+  }
+  return input;
+}
+
 AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
                                  const CNodePtr &sync_bn_cnode) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(allreduce_input);
   MS_EXCEPTION_IF_NULL(sync_bn_cnode);
 
+  // Cast input to fp32, this can reduce the number of cast node. Since the input of AllReduce,
+  // BNTrainingReduce/BNTrainingUpdateGrad op only support fp32 output, when inferred output is fp16, it will
+  // insert cast: output_fp32->cast_fp16->allreduce&mul->cast_fp32. Add this cast can eliminate above cast.
+  // Should be removed if BNTrainingReduce/BNTrainingUpdateGrad op support fp16 output.
+  AnfNodePtr input_node = InsertCast(graph, allreduce_input, kNumberTypeFloat32);
+
   // create AllReduce
-  std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)),
-                                              allreduce_input};
+  std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)), input_node};
   auto allreduce = graph->NewCNode(allreduce_inputs);
   MS_EXCEPTION_IF_NULL(allreduce);
-  allreduce->set_abstract(allreduce_input->abstract());
+  allreduce->set_abstract(input_node->abstract());
   allreduce->set_scope(allreduce_input->scope());
   AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce);
   AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce);
@@ -216,9 +232,12 @@ AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &al
                                               device_num_reciprocal_vnode};
   auto mul = graph->NewCNode(mul_inputs);
   MS_EXCEPTION_IF_NULL(mul);
-  mul->set_abstract(allreduce_input->abstract());
+  mul->set_abstract(input_node->abstract());
   mul->set_scope(allreduce_input->scope());
-  return mul;
+
+  // Cast output to origin datatype to reduce the number of cast node.
+  // Should be removed if BNTrainingReduce/BNTrainingUpdateGrad op support fp16 output.
+  return InsertCast(graph, mul, AnfAlgo::GetOutputInferDataType(allreduce_input, 0));
 }
 
 const BaseRef BnSplit::DefinePattern() const {