| @@ -102,6 +102,7 @@ | |||||
| #include "backend/optimizer/ascend/format_type/remove_internal_output.h" | #include "backend/optimizer/ascend/format_type/remove_internal_output.h" | ||||
| #include "backend/optimizer/ascend/ir_fission/concat_fission.h" | #include "backend/optimizer/ascend/ir_fission/concat_fission.h" | ||||
| #include "backend/optimizer/ascend/ir_fission/pack_fission.h" | #include "backend/optimizer/ascend/ir_fission/pack_fission.h" | ||||
| #include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h" | |||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| @@ -341,6 +342,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||||
| auto other_pm = std::make_shared<PassManager>("other_pm"); | auto other_pm = std::make_shared<PassManager>("other_pm"); | ||||
| other_pm->AddPass(std::make_shared<AllReduceFusion>()); | other_pm->AddPass(std::make_shared<AllReduceFusion>()); | ||||
| other_pm->AddPass(std::make_shared<AllGatherFusion>()); | other_pm->AddPass(std::make_shared<AllGatherFusion>()); | ||||
| other_pm->AddPass(std::make_shared<ConcatOutputsForAllGather>()); | |||||
| other_pm->AddPass(std::make_shared<ReduceScatterFusion>()); | other_pm->AddPass(std::make_shared<ReduceScatterFusion>()); | ||||
| other_pm->AddPass(std::make_shared<BroadcastFusion>()); | other_pm->AddPass(std::make_shared<BroadcastFusion>()); | ||||
| other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>()); | other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>()); | ||||
| @@ -0,0 +1,104 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h" | |||||
| #include <string> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| void AddOutputs(const AnfNodePtr &node, int rank_size) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto origin_abstract = node->abstract(); | |||||
| MS_EXCEPTION_IF_NULL(origin_abstract); | |||||
| auto tuple_abstract = origin_abstract->cast<abstract::AbstractTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tuple_abstract); | |||||
| auto &origin_abstracts = tuple_abstract->elements(); | |||||
| AbstractBasePtrList abstract_list; | |||||
| std::vector<TypeId> outputs_device_type; | |||||
| std::vector<std::string> outputs_device_format; | |||||
| for (int i = 0; i < rank_size; ++i) { | |||||
| for (size_t j = 0; j < origin_abstracts.size(); ++j) { | |||||
| abstract_list.push_back(origin_abstracts[j]); | |||||
| outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, j)); | |||||
| outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, j)); | |||||
| } | |||||
| } | |||||
| // Update abstract | |||||
| auto new_abstracts = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||||
| node->set_abstract(new_abstracts); | |||||
| // Update kernel build info | |||||
| auto builder = | |||||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node)); | |||||
| builder->SetOutputsDeviceType(outputs_device_type); | |||||
| builder->SetOutputsFormat(outputs_device_format); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||||
| } | |||||
| } // namespace | |||||
| AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const std::vector<AnfNodePtr> &new_tuple_getitems, | |||||
| int rank_size) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||||
| size_t inputs_size = AnfAlgo::GetInputTensorNum(node); | |||||
| for (size_t i = 0; i < inputs_size; ++i) { | |||||
| for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) { | |||||
| std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))}; | |||||
| concat_inputs.push_back(new_tuple_getitems[idx]); | |||||
| auto concat = func_graph->NewCNode(concat_inputs); | |||||
| MS_EXCEPTION_IF_NULL(concat); | |||||
| MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]); | |||||
| concat->set_abstract(new_tuple_getitems[idx]->abstract()); | |||||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat); | |||||
| AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat); | |||||
| std::vector<int> dyn_input_size{rank_size}; | |||||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat); | |||||
| kernel_select_->SelectKernel(concat); | |||||
| make_tuple_inputs.push_back(concat); | |||||
| } | |||||
| } | |||||
| auto make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||||
| return make_tuple; | |||||
| } | |||||
| const BaseRef ConcatOutputsForAllGather::DefinePattern() const { | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| auto prim = std::make_shared<Primitive>(kAllGatherOpName); | |||||
| return VectorRef({prim, Xs}); | |||||
| } | |||||
| const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrFusion, cnode) || !AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto fusion = AnfAlgo::GetNodeAttr<int>(cnode, kAttrFusion); | |||||
| if (fusion <= 0) { | |||||
| return nullptr; | |||||
| } | |||||
| auto rank_size = AnfAlgo::GetNodeAttr<int>(node, kAttrRankSize); | |||||
| AddOutputs(node, rank_size); | |||||
| std::vector<AnfNodePtr> new_outputs; | |||||
| CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs); | |||||
| return InsertConcatForOutput(func_graph, node, new_outputs, rank_size); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "backend/optimizer/ascend/ascend_helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class ConcatOutputsForAllGather : public PatternProcessPass { | |||||
| public: | |||||
| explicit ConcatOutputsForAllGather(bool multigraph = true) | |||||
| : PatternProcessPass("concat_outputs_for_all_gather", multigraph), | |||||
| kernel_select_(std::make_shared<KernelSelect>()) {} | |||||
| ~ConcatOutputsForAllGather() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const std::vector<AnfNodePtr> &new_tuple_getitems, int rank_size) const; | |||||
| KernelSelectPtr kernel_select_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_ | |||||
| @@ -188,9 +188,13 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr | |||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | ||||
| MS_EXCEPTION_IF_NULL(abstract_tuple); | MS_EXCEPTION_IF_NULL(abstract_tuple); | ||||
| fused_node->set_abstract(abstract_tuple); | fused_node->set_abstract(abstract_tuple); | ||||
| AnfAlgo::CopyNodeAttr("fusion", communication_op_info.communication_op_nodes[end_index], fused_node); | |||||
| AnfAlgo::CopyNodeAttr("op", communication_op_info.communication_op_nodes[end_index], fused_node); | |||||
| AnfAlgo::CopyNodeAttr("group", communication_op_info.communication_op_nodes[end_index], fused_node); | |||||
| auto final_node = communication_op_info.communication_op_nodes[end_index]; | |||||
| AnfAlgo::CopyNodeAttr(kAttrFusion, final_node, fused_node); | |||||
| AnfAlgo::CopyNodeAttr(kAttrOp, final_node, fused_node); | |||||
| AnfAlgo::CopyNodeAttr(kAttrGroup, final_node, fused_node); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node)) { | |||||
| AnfAlgo::CopyNodeAttr(kAttrRankSize, final_node, fused_node); | |||||
| } | |||||
| return fused_node; | return fused_node; | ||||
| } | } | ||||
| @@ -250,6 +250,7 @@ constexpr auto kAttrChildGraph = "child_graph"; | |||||
| constexpr auto kAttrInputNums = "inputNums"; | constexpr auto kAttrInputNums = "inputNums"; | ||||
| constexpr auto kAttrT = "T"; | constexpr auto kAttrT = "T"; | ||||
| constexpr auto kAttrNum = "num"; | constexpr auto kAttrNum = "num"; | ||||
| constexpr auto kAttrRankSize = "rank_size"; | |||||
| // attr value | // attr value | ||||
| constexpr auto kValueTargetSwitch = "target_switch"; | constexpr auto kValueTargetSwitch = "target_switch"; | ||||