From: @alouhahahahaha Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -133,8 +133,22 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const { | |||
| if (!output_size_list_.empty()) { | |||
| return output_size_list_; | |||
| } | |||
| for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) { | |||
| if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_output_shape_list_[i], &size)) { | |||
| auto cnode = anf_node_->cast<CNodePtr>(); | |||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||
| int64_t rank_size = 1; | |||
| if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) { | |||
| rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize); | |||
| } | |||
| int64_t fusion = 0; | |||
| if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode)) { | |||
| fusion = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion); | |||
| } | |||
| ulong loop_size = hccl_data_type_list_.size(); | |||
| if (op_name == kAllGatherOpName && fusion >= 1) { | |||
| loop_size *= rank_size; | |||
| } | |||
| for (ulong i = 0; i < loop_size; ++i) { | |||
| if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) { | |||
| MS_LOG(ERROR) << "GetHcclOpOutputSize failed"; | |||
| } | |||
| output_size_list_.push_back(size); | |||
| @@ -127,7 +127,12 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp | |||
| total_size = total_size + block_size; | |||
| } else { | |||
| if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) { | |||
| block_size = input_size; | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) { | |||
| block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; | |||
| } else { | |||
| block_size = input_size; | |||
| } | |||
| } else { | |||
| block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; | |||
| } | |||
| @@ -20,58 +20,33 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| void AddOutputs(const AnfNodePtr &node, int64_t rank_size) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto origin_abstract = node->abstract(); | |||
| MS_EXCEPTION_IF_NULL(origin_abstract); | |||
| auto tuple_abstract = origin_abstract->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_abstract); | |||
| auto &origin_abstracts = tuple_abstract->elements(); | |||
| AbstractBasePtrList abstract_list; | |||
| std::vector<TypeId> outputs_device_type; | |||
| std::vector<std::string> outputs_device_format; | |||
| for (int64_t i = 0; i < rank_size; ++i) { | |||
| for (size_t j = 0; j < origin_abstracts.size(); ++j) { | |||
| abstract_list.push_back(origin_abstracts[j]); | |||
| outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, j)); | |||
| outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, j)); | |||
| } | |||
| } | |||
| // Update abstract | |||
| auto new_abstracts = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| node->set_abstract(new_abstracts); | |||
| // Update kernel build info | |||
| auto builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node)); | |||
| builder->SetOutputsDeviceType(outputs_device_type); | |||
| builder->SetOutputsFormat(outputs_device_format); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||
| } | |||
| } // namespace | |||
| AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const std::vector<AnfNodePtr> &new_tuple_getitems, | |||
| int64_t rank_size) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))}; | |||
| size_t inputs_size = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t i = 0; i < inputs_size; ++i) { | |||
| for (size_t j = 0, idx = i; j < LongToSize(rank_size); ++j, idx += inputs_size) { | |||
| std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))}; | |||
| std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))}; | |||
| for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) { | |||
| concat_inputs.push_back(new_tuple_getitems[idx]); | |||
| auto concat = func_graph->NewCNode(concat_inputs); | |||
| MS_EXCEPTION_IF_NULL(concat); | |||
| MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]); | |||
| concat->set_abstract(new_tuple_getitems[idx]->abstract()); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat); | |||
| AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat); | |||
| std::vector<int64_t> dyn_input_size{rank_size}; | |||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat); | |||
| kernel_select_->SelectKernel(concat); | |||
| make_tuple_inputs.push_back(concat); | |||
| } | |||
| auto concat = func_graph->NewCNode(concat_inputs); | |||
| MS_EXCEPTION_IF_NULL(concat); | |||
| MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]); | |||
| auto dtypes = {AnfAlgo::GetOutputInferDataType(new_tuple_getitems[i], 0)}; | |||
| std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(new_tuple_getitems[i], 0); | |||
| shape[0] *= rank_size; | |||
| std::vector<std::vector<size_t>> shapes = {shape}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), concat); | |||
| AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat); | |||
| std::vector<int64_t> dyn_input_size{rank_size}; | |||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat); | |||
| kernel_select_->SelectKernel(concat); | |||
| make_tuple_inputs.push_back(concat); | |||
| } | |||
| auto make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||
| return make_tuple; | |||
| } | |||
| @@ -94,8 +69,11 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra | |||
| if (fusion <= 0) { | |||
| return nullptr; | |||
| } | |||
| if (AnfAlgo::HasNodeAttr("fused", cnode)) { | |||
| return nullptr; | |||
| } | |||
| AnfAlgo::SetNodeAttr("fused", MakeValue(true), node); | |||
| auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize); | |||
| AddOutputs(node, rank_size); | |||
| std::vector<AnfNodePtr> new_outputs; | |||
| CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs); | |||
| return InsertConcatForOutput(func_graph, node, new_outputs, rank_size); | |||
| @@ -46,15 +46,23 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| for (size_t idx = start_index; idx <= end_index; ++idx) { | |||
| auto cnode = communication_op_info.communication_op_nodes[idx]; | |||
| int64_t rank_size = 1; | |||
| if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) && AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) { | |||
| rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||
| inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); | |||
| inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||
| outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); | |||
| outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); | |||
| for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) { | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||
| outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); | |||
| outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); | |||
| std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index); | |||
| shape[0] /= rank_size; | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); | |||
| } | |||
| } | |||
| builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); | |||
| builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); | |||
| @@ -182,18 +190,27 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| fused_node->set_kernel_info(kernel_info); | |||
| AbstractBasePtrList abstract_list; | |||
| for (size_t idx = start_index; idx <= end_index; ++idx) { | |||
| auto cnode = communication_op_info.communication_op_nodes[idx]; | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| abstract_list.push_back(cnode->abstract()); | |||
| auto final_node = communication_op_info.communication_op_nodes[end_index]; | |||
| size_t node_num = end_index - start_index + 1; | |||
| int64_t rank_size = 1; | |||
| if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) && AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) { | |||
| rank_size = AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize); | |||
| } | |||
| size_t output_num = node_num * rank_size; | |||
| std::vector<TypeId> dtypes(output_num, AnfAlgo::GetOutputInferDataType(final_node, 0)); | |||
| std::vector<std::vector<size_t>> shapes; | |||
| for (size_t i = 0; i < IntToSize(rank_size); ++i) { | |||
| for (size_t idx = start_index; idx <= end_index; ++idx) { | |||
| auto cnode = communication_op_info.communication_op_nodes[idx]; | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, 0); | |||
| shape[0] /= rank_size; | |||
| shapes.push_back(shape); | |||
| } | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get()); | |||
| auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index); | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get()); | |||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| MS_EXCEPTION_IF_NULL(abstract_tuple); | |||
| fused_node->set_abstract(abstract_tuple); | |||
| auto final_node = communication_op_info.communication_op_nodes[end_index]; | |||
| AnfAlgo::CopyNodeAttr(kAttrFusion, final_node, fused_node); | |||
| AnfAlgo::CopyNodeAttr(kAttrOp, final_node, fused_node); | |||
| AnfAlgo::CopyNodeAttr(kAttrGroup, final_node, fused_node); | |||