@@ -133,8 +133,22 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
   if (!output_size_list_.empty()) {
     return output_size_list_;
   }
-  for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) {
-    if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_output_shape_list_[i], &size)) {
+  auto cnode = anf_node_->cast<CNodePtr>();
+  auto op_name = AnfAlgo::GetCNodeName(cnode);
+  int64_t rank_size = 1;
+  if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
+    rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
+  }
+  int64_t fusion = 0;
+  if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode)) {
+    fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
+  }
+  ulong loop_size = hccl_data_type_list_.size();
+  if (op_name == kAllGatherOpName && fusion >= 1) {
+    loop_size *= rank_size;
+  }
+  for (ulong i = 0; i < loop_size; ++i) {
+    if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) {
       MS_LOG(ERROR) << "GetHcclOpOutputSize failed";
     }
     output_size_list_.push_back(size);
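Note: the hunk above sizes the output list for a fused AllGather, whose output count grows by a factor of rank_size because every rank's slice becomes a separate output. A minimal standalone sketch of that arithmetic (hypothetical names, not MindSpore API):

#include <cassert>
#include <cstddef>

// For a fused AllGather over input_count tensors across rank_size ranks,
// the kernel exposes one output slice per (input, rank) pair; the unfused
// form keeps one already-gathered output per input.
size_t FusedAllGatherOutputCount(size_t input_count, size_t rank_size, bool fused) {
  return fused ? input_count * rank_size : input_count;
}

int main() {
  assert(FusedAllGatherOutputCount(3, 8, true) == 24);   // fused: 3 inputs x 8 ranks
  assert(FusedAllGatherOutputCount(3, 8, false) == 3);   // unfused: one output per input
  return 0;
}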
@@ -127,7 +127,12 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
       total_size = total_size + block_size;
     } else {
       if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {
-        block_size = input_size;
+        auto cnode = anf_node->cast<CNodePtr>();
+        if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) {
+          block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
+        } else {
+          block_size = input_size;
+        }
       } else {
         block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
       }
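Note: the branch above switches fused AllGather inputs from the raw input_size to the same rounded-up block size other ops use. A small sketch of that round-up expression in isolation (hypothetical helper; filled_size is set to 0 in the examples to keep the arithmetic visible):

#include <cassert>
#include <cstdint>

// Integer round-up: adding align_size - 1 before dividing bumps any
// remainder to the next align_size boundary; filled_size is folded in
// before rounding, mirroring the expression in GetHcomCount.
uint64_t AlignBlockSize(uint64_t input_size, uint64_t align_size, uint64_t filled_size) {
  return (input_size + align_size - 1 + filled_size) / align_size * align_size;
}

int main() {
  assert(AlignBlockSize(100, 512, 0) == 512);   // 100 bytes round up to one 512-byte block
  assert(AlignBlockSize(512, 512, 0) == 512);   // exact multiples are unchanged
  assert(AlignBlockSize(513, 512, 0) == 1024);  // one byte over rolls to the next block
  return 0;
}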
@@ -20,58 +20,33 @@
 namespace mindspore {
 namespace opt {
-namespace {
-void AddOutputs(const AnfNodePtr &node, int64_t rank_size) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto origin_abstract = node->abstract();
-  MS_EXCEPTION_IF_NULL(origin_abstract);
-  auto tuple_abstract = origin_abstract->cast<abstract::AbstractTuplePtr>();
-  MS_EXCEPTION_IF_NULL(tuple_abstract);
-  auto &origin_abstracts = tuple_abstract->elements();
-  AbstractBasePtrList abstract_list;
-  std::vector<TypeId> outputs_device_type;
-  std::vector<std::string> outputs_device_format;
-  for (int64_t i = 0; i < rank_size; ++i) {
-    for (size_t j = 0; j < origin_abstracts.size(); ++j) {
-      abstract_list.push_back(origin_abstracts[j]);
-      outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, j));
-      outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, j));
-    }
-  }
-  // Update abstract
-  auto new_abstracts = std::make_shared<abstract::AbstractTuple>(abstract_list);
-  node->set_abstract(new_abstracts);
-  // Update kernel build info
-  auto builder =
-    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
-  builder->SetOutputsDeviceType(outputs_device_type);
-  builder->SetOutputsFormat(outputs_device_format);
-  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
-}
-}  // namespace
 AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                             const std::vector<AnfNodePtr> &new_tuple_getitems,
                                                             int64_t rank_size) const {
   MS_EXCEPTION_IF_NULL(func_graph);
-  std::vector<AnfNodePtr> make_tuple_inputs;
+  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
   size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
   for (size_t i = 0; i < inputs_size; ++i) {
-    for (size_t j = 0, idx = i; j < LongToSize(rank_size); ++j, idx += inputs_size) {
-      std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
+    std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
+    for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) {
       concat_inputs.push_back(new_tuple_getitems[idx]);
-      auto concat = func_graph->NewCNode(concat_inputs);
-      MS_EXCEPTION_IF_NULL(concat);
-      MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]);
-      concat->set_abstract(new_tuple_getitems[idx]->abstract());
-      AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat);
-      AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
-      std::vector<int64_t> dyn_input_size{rank_size};
-      AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
-      kernel_select_->SelectKernel(concat);
-      make_tuple_inputs.push_back(concat);
     }
+    auto concat = func_graph->NewCNode(concat_inputs);
+    MS_EXCEPTION_IF_NULL(concat);
+    MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]);
+    auto dtypes = {AnfAlgo::GetOutputInferDataType(new_tuple_getitems[i], 0)};
+    std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(new_tuple_getitems[i], 0);
+    shape[0] *= rank_size;
+    std::vector<std::vector<size_t>> shapes = {shape};
+    AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
+    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat);
+    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
+    std::vector<int64_t> dyn_input_size{rank_size};
+    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
+    kernel_select_->SelectKernel(concat);
+    make_tuple_inputs.push_back(concat);
   }
   auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
   return make_tuple;
 }
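Note: the rewritten loop gathers all rank_size getitem outputs into one Concat per input, relying on the fused node's outputs being laid out rank-major, so the slices of input i sit at a stride of inputs_size. A sketch of that index pattern (hypothetical helper, not part of the patch):

#include <cassert>
#include <cstddef>
#include <vector>

// Outputs of the fused node are ordered rank-major, so input i's slices
// live at indices i, i + inputs_size, i + 2 * inputs_size, ...
std::vector<size_t> ConcatIndicesForInput(size_t i, size_t inputs_size, size_t rank_size) {
  std::vector<size_t> indices;
  for (size_t j = 0, idx = i; j < rank_size; ++j, idx += inputs_size) {
    indices.push_back(idx);  // one slice of input i from each rank
  }
  return indices;
}

int main() {
  // 2 inputs gathered across 3 ranks: input 0 concatenates outputs {0, 2, 4}.
  assert((ConcatIndicesForInput(0, 2, 3) == std::vector<size_t>{0, 2, 4}));
  assert((ConcatIndicesForInput(1, 2, 3) == std::vector<size_t>{1, 3, 5}));
  return 0;
}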
@@ -94,8 +69,11 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
   if (fusion <= 0) {
     return nullptr;
   }
+  if (AnfAlgo::HasNodeAttr("fused", cnode)) {
+    return nullptr;
+  }
+  AnfAlgo::SetNodeAttr("fused", MakeValue(true), node);
   auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
-  AddOutputs(node, rank_size);
   std::vector<AnfNodePtr> new_outputs;
   CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
   return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
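Note: the new "fused" attribute acts as a run-once guard, so a repeated traversal leaves an already-rewritten AllGather alone. The same idempotency pattern in isolation (hypothetical types, standing in for the attribute check):

#include <set>

struct Node {};

// Returns true only on the first visit; the set stands in for the
// "fused" node attribute written by AnfAlgo::SetNodeAttr above.
bool ProcessOnce(Node *node, std::set<Node *> *fused) {
  if (fused->count(node) != 0) {
    return false;  // already rewritten: skip, mirroring the HasNodeAttr check
  }
  fused->insert(node);  // mark before rewriting, as the pass does
  return true;
}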
@@ -46,15 +46,23 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co
   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
   for (size_t idx = start_index; idx <= end_index; ++idx) {
     auto cnode = communication_op_info.communication_op_nodes[idx];
+    int64_t rank_size = 1;
+    if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) && AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
+      rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
+    }
     MS_EXCEPTION_IF_NULL(cnode);
     for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
       inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
       inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
     }
-    for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
-      outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
-      outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
-      outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
+    for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) {
+      for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
+        outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
+        outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
+        std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
+        shape[0] /= rank_size;
+        outputs_shape.push_back(shape);
+      }
     }
     builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
     builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
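Note: for a fused AllGather the build info now advertises rank_size copies of each output, each with dim 0 divided by rank_size, i.e. the gathered dimension split back into per-rank slices. A sketch of that shape expansion (hypothetical helper):

#include <cstddef>
#include <vector>

// A gathered output of shape [rank_size * d0, d1, ...] is re-exposed as
// rank_size outputs of shape [d0, d1, ...], one per rank, as in the
// nested (rank_index, output_index) loop above.
std::vector<std::vector<size_t>> PerRankShapes(const std::vector<size_t> &gathered, size_t rank_size) {
  std::vector<size_t> slice = gathered;
  slice[0] /= rank_size;  // dim 0 is split evenly across ranks
  return std::vector<std::vector<size_t>>(rank_size, slice);
}

int main() {
  // A gathered [8, 16] tensor across 4 ranks becomes four [2, 16] outputs.
  auto shapes = PerRankShapes({8, 16}, 4);
  return (shapes.size() == 4 && shapes[0][0] == 2) ? 0 : 1;
}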
@@ -182,18 +190,27 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
   auto kernel_info = std::make_shared<device::KernelInfo>();
   MS_EXCEPTION_IF_NULL(kernel_info);
   fused_node->set_kernel_info(kernel_info);
-  AbstractBasePtrList abstract_list;
-  for (size_t idx = start_index; idx <= end_index; ++idx) {
-    auto cnode = communication_op_info.communication_op_nodes[idx];
-    MS_EXCEPTION_IF_NULL(cnode);
-    abstract_list.push_back(cnode->abstract());
+  auto final_node = communication_op_info.communication_op_nodes[end_index];
+  size_t node_num = end_index - start_index + 1;
+  int64_t rank_size = 1;
+  if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) && AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
+    rank_size = AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
+  }
+  size_t output_num = node_num * rank_size;
+  std::vector<TypeId> dtypes(output_num, AnfAlgo::GetOutputInferDataType(final_node, 0));
+  std::vector<std::vector<size_t>> shapes;
+  for (size_t i = 0; i < IntToSize(rank_size); ++i) {
+    for (size_t idx = start_index; idx <= end_index; ++idx) {
+      auto cnode = communication_op_info.communication_op_nodes[idx];
+      MS_EXCEPTION_IF_NULL(cnode);
+      std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, 0);
+      shape[0] /= rank_size;
+      shapes.push_back(shape);
+    }
   }
+  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get());
   auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
-  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
-  MS_EXCEPTION_IF_NULL(abstract_tuple);
-  fused_node->set_abstract(abstract_tuple);
-  auto final_node = communication_op_info.communication_op_nodes[end_index];
   AnfAlgo::CopyNodeAttr(kAttrFusion, final_node, fused_node);
   AnfAlgo::CopyNodeAttr(kAttrOp, final_node, fused_node);
   AnfAlgo::CopyNodeAttr(kAttrGroup, final_node, fused_node);
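Note: output_num = node_num * rank_size reflects a rank-major output layout on the fused node, rank 0's slice of every fused op first, then rank 1's, and so on, which is the same ordering InsertConcatForOutput walks with its inputs_size stride. A sketch of that layout (hypothetical helper):

#include <cstddef>
#include <utility>
#include <vector>

// Enumerates the fused node's output slots as (rank, op) pairs in the
// rank-major order produced by the nested loop above.
std::vector<std::pair<size_t, size_t>> FusedOutputLayout(size_t node_num, size_t rank_size) {
  std::vector<std::pair<size_t, size_t>> layout;
  for (size_t rank = 0; rank < rank_size; ++rank) {
    for (size_t op = 0; op < node_num; ++op) {
      layout.emplace_back(rank, op);
    }
  }
  return layout;  // size == node_num * rank_size, the output_num above
}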