| @@ -182,18 +182,34 @@ AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const | |||||
| return device_num_reciprocal_value; | return device_num_reciprocal_value; | ||||
| } | } | ||||
| AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const TypeId dst_type) { | |||||
| if (AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) { | |||||
| AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input}); | |||||
| AnfAlgo::SetOutputInferTypeAndShape({dst_type}, {AnfAlgo::GetOutputInferShape(input, 0)}, cast.get()); | |||||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); | |||||
| cast->set_scope(input->scope()); | |||||
| return cast; | |||||
| } | |||||
| return input; | |||||
| } | |||||
// CreateAllReduceAndMul: builds the cross-device averaging subgraph used by
// SyncBatchNorm splitting — casts `allreduce_input` to fp32 (see comment in
// the code below), applies AllReduce with op=sum and the group copied from
// `sync_bn_cnode`, multiplies by 1/device_num, then casts the result back to
// the input's original dtype via InsertCast.
// NOTE(review): this span is a side-by-side diff rendering; the hunk header
// "@@ -216,9 +232,12 @@" below means some interior lines of this function are
// not visible here, so the code is left byte-identical.
| AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input, | AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input, | ||||
| const CNodePtr &sync_bn_cnode) { | const CNodePtr &sync_bn_cnode) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(allreduce_input); | MS_EXCEPTION_IF_NULL(allreduce_input); | ||||
| MS_EXCEPTION_IF_NULL(sync_bn_cnode); | MS_EXCEPTION_IF_NULL(sync_bn_cnode); | ||||
| // Cast input to fp32, this can reduce the number of cast node. Since the input of AllReduce, | |||||
| // BNTrainingReduce/BNTrainingUpdateGrad op only support fp32 output, when inferred output is fp16, it will | |||||
| // insert cast: output_fp32->cast_fp16->allreduce&mul->cast_fp32. Add this cast can eliminate above cast. | |||||
| // Should be removed if BNTrainingReduce/BNTrainingUpdateGrad op support fp16 output. | |||||
| AnfNodePtr input_node = InsertCast(graph, allreduce_input, kNumberTypeFloat32); | |||||
| // create AllReduce | // create AllReduce | ||||
| std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)), | |||||
| allreduce_input}; | |||||
| std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)), input_node}; | |||||
| auto allreduce = graph->NewCNode(allreduce_inputs); | auto allreduce = graph->NewCNode(allreduce_inputs); | ||||
| MS_EXCEPTION_IF_NULL(allreduce); | MS_EXCEPTION_IF_NULL(allreduce); | ||||
| allreduce->set_abstract(allreduce_input->abstract()); | |||||
// The AllReduce now takes its abstract from the (possibly cast) fp32 node,
// while scope still comes from the original input.
| allreduce->set_abstract(input_node->abstract()); | |||||
| allreduce->set_scope(allreduce_input->scope()); | allreduce->set_scope(allreduce_input->scope()); | ||||
| AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce); | AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce); | ||||
| AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce); | AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce); | ||||
// --- hunk gap: intermediate lines (creation of the 1/device_num value node
// and the start of the Mul inputs) are not visible in this view ---
| @@ -216,9 +232,12 @@ AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &al | |||||
| device_num_reciprocal_vnode}; | device_num_reciprocal_vnode}; | ||||
| auto mul = graph->NewCNode(mul_inputs); | auto mul = graph->NewCNode(mul_inputs); | ||||
| MS_EXCEPTION_IF_NULL(mul); | MS_EXCEPTION_IF_NULL(mul); | ||||
| mul->set_abstract(allreduce_input->abstract()); | |||||
| mul->set_abstract(input_node->abstract()); | |||||
| mul->set_scope(allreduce_input->scope()); | mul->set_scope(allreduce_input->scope()); | ||||
| return mul; | |||||
| // Cast output to origin datatype to reduce the number of cast node. | |||||
| // Should be removed if BNTrainingReduce/BNTrainingUpdateGrad op support fp16 output. | |||||
// The final cast restores the caller-visible dtype, so the fp32 detour is
// transparent to downstream nodes.
| return InsertCast(graph, mul, AnfAlgo::GetOutputInferDataType(allreduce_input, 0)); | |||||
| } | } | ||||
| const BaseRef BnSplit::DefinePattern() const { | const BaseRef BnSplit::DefinePattern() const { | ||||