From: @tronzhang
Reviewed-by: @gaoxiong1, @anyrenwei
Signed-off-by: @anyrenwei
pull/14364/MERGE
@@ -15,12 +15,15 @@
  */
 #include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
 
-#include <memory>
 #include <algorithm>
+#include <map>
+#include <memory>
 #include <unordered_set>
 #include <unordered_map>
+#include <vector>
+#include <utility>
+#include <set>
 #include <string>
-#include <vector>
 
 #include "base/core_ops.h"
 #include "ir/graph_utils.h"
@@ -1,166 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/optimizer/graph_kernel/depend_formater.h"
-#include <tuple>
-#include <utility>
-#include <vector>
-#include "backend/session/anf_runtime_algorithm.h"
-#include "backend/kernel_compiler/common_utils.h"
-#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-bool RemoveRedundantDepend(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
-  const auto &users = mng->node_users()[node];
-  std::vector<std::pair<AnfNodePtr, int>> sons;
-  for (const auto &[user, index] : users) {
-    if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
-      sons.emplace_back(user, index);
-      continue;
-    }
-    auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin());
-    sons.emplace_back(fake_first_grad_son, grad_index);
-  }
-
-  AnfNodePtrList latter_to_delete;
-  for (const auto &[son, index] : sons) {
-    if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
-      continue;
-    }
-    latter_to_delete.push_back(son);
-  }
-
-  if (latter_to_delete.empty()) {
-    return false;
-  }
-
-  std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
-  if (latter_to_delete.size() == sons.size()) {
-    // Left one Depend node relation and delete others!
-    ++delete_begin;
-  }
-  for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
-    auto depend_anfnode = *delete_begin;
-    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(depend_cnode);
-    auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
-    mng->Replace(depend_anfnode, depend_prior_node);
-  }
-  return true;
-}
-
-AnfNodePtr FindPatronNode(const FuncGraphPtr &main_graph, const FuncGraphManagerPtr &mng) {
-  AnfNodePtr patron_node;
-  auto return_cnode = main_graph->get_return()->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(return_cnode);
-  auto output_node = return_cnode->input(kFirstDataInputIndex);
-  if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
-    auto output_cnode = output_node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(output_cnode);
-    patron_node = output_cnode->input(kFirstDataInputIndex);
-  } else {
-    patron_node = output_node;
-  }
-  return patron_node;
-}
-
-void AddDepends(const AnfNodePtr &stable_node, const AnfNodePtrList &free_nodes, const FuncGraphPtr &main_graph,
-                const FuncGraphManagerPtr &mng) {
-  AnfNodePtr modified_node = stable_node;
-  for (const auto &free_node : free_nodes) {
-    AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), modified_node, free_node};
-    auto depend_cnode = main_graph->NewCNode(d_inputs);
-    depend_cnode->set_abstract(modified_node->abstract());
-    main_graph->AddNode(depend_cnode);
-    modified_node = depend_cnode;
-  }
-
-  if (!free_nodes.empty()) {
-    mng->Replace(stable_node, modified_node);
-  }
-}
-}  // namespace
-
-bool DependFormater::Run(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  auto mng = func_graph->manager();
-  if (mng == nullptr) {
-    mng = Manage(func_graph, true);
-    func_graph->set_manager(mng);
-  }
-
-  // 1. Try to remove redundant depend.
-  bool changed = false;
-  auto nodes = TopoSort(func_graph->get_return());
-  std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) -> void {
-    if (HasAbstractMonad(node)) {
-      return;
-    }
-    if (RemoveRedundantDepend(node, mng)) {
-      changed = true;
-    }
-  });
-
-  // Should re-toposort for changed graph.
-  if (changed) {
-    nodes = TopoSort(func_graph->get_return());
-  }
-
-  // 2. Move depend to tail of graph.
-  AnfNodePtrList old_depends;
-  AnfNodePtrList free_nodes;
-
-  // Find depend and its free nodes.
-  for (const auto &node : nodes) {
-    if (!IsPrimitiveCNode(node, prim::kPrimDepend) ||
-        HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) {
-      continue;
-    }
-    old_depends.push_back(node);
-    auto cnode = node->cast<CNodePtr>();
-    for (size_t id = kDependAttachNodeIndex; id < cnode->inputs().size(); ++id) {
-      auto attach_node = cnode->input(id);
-      if (!IsPrimitiveCNode(attach_node, prim::kPrimDepend)) {
-        continue;
-      }
-      free_nodes.push_back(attach_node);
-    }
-  }
-
-  if (old_depends.empty()) {
-    return changed;
-  }
-
-  // Delete old depend.
-  for (const auto &depend_anfnode : old_depends) {
-    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(depend_cnode);
-    auto depend_prior_node = depend_cnode->input(kControlDependPriorIndex);
-    mng->Replace(depend_anfnode, depend_prior_node);
-  }
-
-  // Add new depend node in tail.
-  AnfNodePtr patron_node = FindPatronNode(func_graph, mng);
-  AddDepends(patron_node, free_nodes, func_graph, mng);
-  return true;
-}
-}  // namespace opt
-}  // namespace mindspore
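Reviewer note: the file removed above implemented the DependFormater pass, which first pruned redundant Depend users and then re-chained the remaining "free" attach nodes behind a single patron node at the graph output. For readers unfamiliar with the pass, here is a minimal standalone sketch of that chaining step; the Node type below is a toy stand-in, not MindSpore's AnfNode/FuncGraph API.

// Toy model of the deleted AddDepends() helper: every "free" node is folded into a
// Depend chain rooted at the patron node, so ordering constraints collect at the
// graph tail. Types here are illustrative stand-ins, not MindSpore's IR classes.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Depend(value, attach) forwards `value` but keeps `attach` in the execution order.
NodePtr MakeDepend(const NodePtr &value, const NodePtr &attach) {
  return std::make_shared<Node>(Node{"Depend(" + value->name + ", " + attach->name + ")", {value, attach}});
}

// Mirrors the shape of the removed AddDepends(): chain each free node behind the patron.
NodePtr ChainDepends(NodePtr patron, const std::vector<NodePtr> &free_nodes) {
  for (const auto &free_node : free_nodes) {
    patron = MakeDepend(patron, free_node);
  }
  return patron;
}

int main() {
  auto out = std::make_shared<Node>(Node{"out", {}});
  auto a = std::make_shared<Node>(Node{"a", {}});
  auto b = std::make_shared<Node>(Node{"b", {}});
  std::cout << ChainDepends(out, {a, b})->name << std::endl;  // Depend(Depend(out, a), b)
  return 0;
}

In the real pass the tail of this chain then replaced the original patron node through the graph manager, as the deleted Run() body above shows.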
@@ -1,37 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
-#include <map>
-#include <memory>
-#include "backend/optimizer/common/pass.h"
-#include "ir/func_graph.h"
-
-namespace mindspore {
-namespace opt {
-class DependFormater : public Pass {
- public:
-  DependFormater() : Pass("depend_formater") {}
-  ~DependFormater() override = default;
-  bool Run(const FuncGraphPtr &graph) override;
-};
-using DependFormaterPtr = std::shared_ptr<DependFormater>;
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
@@ -26,7 +26,6 @@
 #include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
 #include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
 #include "backend/optimizer/graph_kernel/clean_all_in_once.h"
-#include "backend/optimizer/graph_kernel/depend_formater.h"
 #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
 #include "backend/optimizer/graph_kernel/tensor_promotion.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
@@ -50,9 +49,6 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
   // Change Assign(p, a, U) to Assign(Depend(p, U), a)
   pm->AddPass(std::make_shared<SplitAssign>());
-
-  // Move the Depend nodes to the bottom of graph
-  pm->AddPass(std::make_shared<DependFormater>());
   // Reorder TransData-Cast to Cast-TransData,
   if (is_ascend) {
     pm->AddPass(std::make_shared<ReorderOps>());
@@ -142,8 +138,6 @@ PassManagerPtr GraphKernelOptimizer::Combine() {
   auto pm = std::make_shared<PassManager>("graphkernel_stage6_combine");
   // Enable parallel fusion
   if (is_gpu) {
-    // Prevent fake loop in parallel fusion
-    pm->AddPass(std::make_shared<DependFormater>());
     // Do parallel fusion for gpu device
     pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7)));
   }
@@ -142,10 +142,9 @@ void UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &
   for (const auto &getitem_user_iter : getitem_users) {
     auto getitem_user = getitem_user_iter.first;
-    // 1. A previous pass `DependFormater` has ensured that all data users are directly link to its
-    //    input, without Depend node.
+    // 1. Data users may not link directly to its input, they may segregated by Depend node.
     // 2. If the `cnode` has another path to the getitem_user, it's unnecessary to add update_state and load node to
-    //    keep exec_order.
+    //    keep exec_order.
     if (HasPathToParamUser(cnode, getitem_user, getitem)) {
       mng->Replace(getitem, assign_to);
       continue;
@@ -95,33 +95,6 @@ void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn,
   }
 }
-
-void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
-  OrderedSet<AnfNodePtr> to_be_through_pass;
-  for (auto &[node, node_rel] : (*node_rels)) {
-    if (!IsPrimitiveCNode(node, prim::kPrimDepend) ||
-        HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) {
-      continue;
-    }
-    // Make attached nodes deattach with node.
-    auto cnode = node->cast<CNodePtr>();
-    for (size_t id = kDependAttachNodeIndex; id < cnode->inputs().size(); ++id) {
-      auto attach_node = cnode->input(id);
-      if (auto iter = node_rels->find(attach_node); iter != node_rels->end()) {
-        iter->second.nexts.erase(node);
-      }
-      if (auto &cnode_pres = node_rel.pres; cnode_pres.count(attach_node) != 0) {
-        cnode_pres.erase(attach_node);
-      }
-    }
-    to_be_through_pass.insert(node);
-  }
-  // Eliminate depend node of node relations.
-  ProcessThroughPassCNode([&to_be_through_pass](const AnfNodePtr &node) { return to_be_through_pass.count(node) > 0; },
-                          node_rels);
-}
-
 void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
   AnfNodePtrList latter_to_be_erased;
   for (auto &[node, node_rel] : (*node_rels)) {
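Reviewer note: the ProcessDependCNode() helper deleted above kept a bidirectional relation map consistent while detaching attach-edges from Depend nodes. A small self-contained sketch of that bookkeeping follows; plain STL containers and string keys stand in for OrderedMap and AnfNodePtr, which are assumptions for illustration only.

// Sketch of the edge bookkeeping in the removed ProcessDependCNode(): detaching an
// attach node from a Depend node erases the edge from both the successor set of the
// attach node and the predecessor set of the Depend node.
#include <iostream>
#include <set>
#include <string>
#include <unordered_map>

struct NodeRelation {
  std::set<std::string> pres;   // predecessors (inputs)
  std::set<std::string> nexts;  // successors (users)
};
using RelMap = std::unordered_map<std::string, NodeRelation>;

void DetachEdge(RelMap *rels, const std::string &attach, const std::string &depend) {
  if (auto it = rels->find(attach); it != rels->end()) {
    it->second.nexts.erase(depend);  // attach no longer feeds the Depend node
  }
  if (auto it = rels->find(depend); it != rels->end()) {
    it->second.pres.erase(attach);   // the Depend node no longer lists attach as an input
  }
}

int main() {
  RelMap rels;
  rels["attach"].nexts = {"depend"};
  rels["depend"].pres = {"real_input", "attach"};
  DetachEdge(&rels, "attach", "depend");
  std::cout << rels["depend"].pres.size() << std::endl;  // 1 (only real_input remains)
  return 0;
}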
@@ -441,8 +414,7 @@ OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const An
     auto prior_node = get_info(node);
     for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
-      // Parameter for ControlDepend when depend mode is 1.
-      if (!input->isa<CNode>() && !input->isa<Parameter>()) {
+      if (!input->isa<CNode>()) {
         continue;
       }
       auto behind_node = get_info(input);
@@ -451,13 +423,11 @@ OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const An
     }
   }
-
-  ProcessDependCNode(&node_rels);
   ProcessThroughPassCNode(
     [](const AnfNodePtr &node) {
      return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
    },
    &node_rels);
-  ProcessThroughPassCNode([](const AnfNodePtr &node) { return node->isa<Parameter>(); }, &node_rels);
   ProcessTailMakeTupleCNode(&node_rels);
   ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
@@ -707,51 +677,6 @@ void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &pa
   SetFusionInfoAttrToNode(attach_node, parallel_info);
 }
-
-void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_ptr<session::KernelGraph> &kernel_graph) {
-  auto mng = kernel_graph->manager();
-  if (mng == nullptr) {
-    mng = Manage(kernel_graph, true);
-    kernel_graph->set_manager(mng);
-  }
-
-  const auto &users = mng->node_users()[node];
-  std::vector<std::pair<AnfNodePtr, int>> sons;
-  for (const auto &[user, index] : users) {
-    if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
-      sons.emplace_back(user, index);
-      continue;
-    }
-    auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin());
-    sons.emplace_back(fake_first_grad_son, grad_index);
-  }
-
-  AnfNodePtrList latter_to_delete;
-  for (const auto &[son, index] : sons) {
-    if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
-      continue;
-    }
-    latter_to_delete.push_back(son);
-  }
-
-  if (latter_to_delete.empty()) {
-    return;
-  }
-
-  std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
-  if (latter_to_delete.size() == sons.size()) {
-    // Left one Depend node relation and delete others!
-    ++delete_begin;
-  }
-  for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
-    auto depend_anfnode = *delete_begin;
-    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(depend_cnode);
-    auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
-    mng->Replace(depend_anfnode, depend_prior_node);
-  }
-}
-
 void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info) {
   auto fusion_type = parallel_info.fusion_info()->FusionType();
   AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
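Reviewer note: both removed helpers (RemoveRedundantDepend in the deleted file above and PostProcessForNewSubGraphCNode here) applied the same pruning rule to the Depend users of a node: when every user is a Depend attached through the second input, one Depend is kept and the rest are bypassed; otherwise every such Depend is bypassed. A hedged sketch of just that counting rule; CountDependsToBypass is a hypothetical name used for illustration.

// Sketch of the "keep one Depend when all users are Depends" rule from the removed helpers.
#include <cassert>
#include <cstddef>

std::size_t CountDependsToBypass(std::size_t user_count, std::size_t depend_user_count) {
  if (depend_user_count == 0) {
    return 0;  // nothing to prune
  }
  // When every user is a Depend, keep exactly one edge so the node stays reachable.
  return depend_user_count == user_count ? depend_user_count - 1 : depend_user_count;
}

int main() {
  assert(CountDependsToBypass(3, 3) == 2);  // all users are Depends: keep one
  assert(CountDependsToBypass(3, 2) == 2);  // other users exist: bypass every Depend
  assert(CountDependsToBypass(3, 0) == 0);
  return 0;
}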
@@ -776,7 +701,6 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
     SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
     AnfNodePtr sg_node;
     std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
-    PostProcessForNewSubGraphCNode(sg_node, kernel_graph);
     AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
     DumpParallelFusionDetail(fuse_nodes, sg_node);
   }