From e5306b913d77924e422120c9367448c2d6f52784 Mon Sep 17 00:00:00 2001 From: dayschan Date: Thu, 3 Dec 2020 09:41:32 +0800 Subject: [PATCH] GraphKernel Fuser Refactor the BasicOpsFusion and CompositeOpsFusion to one pass. Add a pass to eliminate the redundant output. TODO: rename the file basic_ops_fusion and delete the file composite_ops_fusion --- .../graph_kernel/expanders/fused_adam.py | 6 +- .../expanders/fused_adam_weight_decay.py | 6 +- .../ascend/ascend_backend_optimization.cc | 2 - .../graph_kernel/basic_ops_fusion.cc | 238 ++++++------- .../graph_kernel/composite_ops_fusion.cc | 193 +---------- .../graph_kernel/composite_ops_fusion.h | 18 +- .../eliminate_redundant_output.cc | 327 ++++++++++++++++++ .../graph_kernel/eliminate_redundant_output.h | 36 ++ .../graph_kernel/graph_kernel_helper.cc | 19 +- .../graph_kernel/graph_kernel_helper.h | 6 +- .../ccsrc/backend/session/ascend_session.cc | 4 +- .../ccsrc/backend/session/gpu_session.cc | 4 +- 12 files changed, 489 insertions(+), 370 deletions(-) create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.h diff --git a/mindspore/_extends/graph_kernel/expanders/fused_adam.py b/mindspore/_extends/graph_kernel/expanders/fused_adam.py index beea8b3ca0..e8db66c71f 100644 --- a/mindspore/_extends/graph_kernel/expanders/fused_adam.py +++ b/mindspore/_extends/graph_kernel/expanders/fused_adam.py @@ -61,11 +61,11 @@ def expand_fusedadam(expand_info): next_para = graph_builder.emit('Sub', [param, update_with_lr]) param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) - m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True}) - v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True}) + param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True}) + param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True}) # set graph output. - graph_scope.set_output(param_result, m_result, v_result) + graph_scope.set_output(param_result) graph = graph_builder.get()[0] return graph diff --git a/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py b/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py index d01170be60..772cadd0d5 100644 --- a/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +++ b/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py @@ -66,11 +66,11 @@ def expand_fusedadamweightdecay(expand_info): next_para = graph_builder.emit('Sub', [param, update_with_lr]) para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) - m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True}) - v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True}) + para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True}) + para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True}) # set graph output. - graph_scope.set_output(para_result, m_result, v_result) + graph_scope.set_output(para_result) graph = graph_builder.get()[0] return graph diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 78ef98b50e..2f3c57c665 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -117,8 +117,6 @@ #include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h" #include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h" #include "utils/ms_context.h" -#include "backend/optimizer/graph_kernel/composite_ops_fusion.h" -#include "backend/optimizer/graph_kernel/basic_ops_fusion.h" #include "utils/config_manager.h" #include "debug/anf_ir_dump.h" #include "debug/dump_proto.h" diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc index 0b9fa666ba..49cb4f52dc 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc @@ -31,132 +31,111 @@ #include "ir/func_graph_cloner.h" #include "backend/optimizer/graph_kernel/composite_ops_fusion.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/optimizer/pass/getitem_tuple.h" namespace mindspore { namespace opt { namespace { +bool IsFusibleOp(const AnfNodePtr &node) { +#if ENABLE_D + const std::set graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", + "LambNextMV", "LambUpdateWithLR"}; + if (AnfAlgo::IsGraphKernel(node)) { + auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + if (fg_attr != nullptr) { + return graph_kernel_black_list.count(GetValue(fg_attr)) == 0; + } + } +#endif + return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node); +} + IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { if (cur_node == node) { return FOLLOW; } - if (!IsPrimitiveCNode(node)) { - return EXCLUDE; - } - - bool is_fusable = IsBasicFuseOp(node); - return is_fusable ? FOLLOW : EXCLUDE; -} - -std::vector FindFuseCNodes(const CNodePtr &cnode, - const std::multimap> &dep_pri) { - // Search fusable nodes according input direction. - auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1); - auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); - if (used_nodes.size() > 1) { - used_nodes = RemoveCircle(used_nodes, dep_pri, false); + if (IsFusibleOp(node)) { + return FOLLOW; } - TopoSortForNodeList(&used_nodes); - return used_nodes; -} - -void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &users, - std::vector *control_depend_nodes, std::vector *control_depend_use_index, - bool *is_only_control_depend_use, AnfNodePtr *use_out) { - for (auto &user : users) { - auto use_node = user.first; - if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { - *is_only_control_depend_use = false; - continue; - } - if (outputs_set.count(use_node) != 0) { - *use_out = use_node; - } - if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { - control_depend_nodes->push_back(use_node->cast()); - control_depend_use_index->push_back(user.second); + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + auto prev_node = node->cast()->input(kRealInputNodeIndexInTupleGetItem); + if (AnfAlgo::IsGraphKernel(prev_node)) { + return FOLLOW; } } + return EXCLUDE; } -bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng, - std::unordered_map *eqv, - std::multimap> *depend_prior) { - AnfNodeSet outputs_set; - for (auto out : *outputs) { - outputs_set.insert(out); - } - bool has_erase_outs = false; - int index = -1; - for (auto it = outputs->begin(); it != outputs->end();) { - index++; - auto out = *it; - (*eqv)[out] = vir_outputs[IntToSize(index)]; - auto users = mng->node_users()[out]; - bool is_only_control_depend_use = true; - std::vector control_depend_use_index; - std::vector control_depend_nodes; - AnfNodePtr use_out = nullptr; - SearchForDependNode(outputs_set, users, &control_depend_nodes, &control_depend_use_index, - &is_only_control_depend_use, &use_out); - if (is_only_control_depend_use && !control_depend_nodes.empty()) { - MS_EXCEPTION_IF_NULL(use_out); - it = outputs->erase(it); - for (size_t i = 0; i < control_depend_nodes.size(); ++i) { - auto control_depend_node = control_depend_nodes[i]; - std::vector new_control_depend_inputs; - for (size_t j = 0; j < control_depend_node->size(); ++j) { - if (j == control_depend_use_index[i]) { - new_control_depend_inputs.push_back(use_out); - } else { - new_control_depend_inputs.push_back(control_depend_node->input(j)); - } - } - auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs); - mng->Replace(control_depend_node, new_control_depend); - has_erase_outs = true; - UpdateControlDependNode(depend_prior, control_depend_node, new_control_depend); +// The GetItem node should be fused with its real input and users. +// If its real input is not in the fuse_list, the GetItem should be excluded. +AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { + if (fused_op.empty()) return AnfNodePtrList(); + std::set fused_op_set(fused_op.begin(), fused_op.end()); + auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; + + auto mng = fused_op[0]->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(mng); + bool changed = true; + while (changed) { + changed = false; + AnfNodePtrList remove_list; + for (auto getitem : fused_op_set) { + if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue; + + // GetItem should be fused with its real input. + auto prev_node = getitem->cast()->input(kRealInputNodeIndexInTupleGetItem); + if (check_include(prev_node) == EXCLUDE) { + remove_list.push_back(getitem); + break; } - } else { - it++; - } - } - return has_erase_outs; -} -void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng, - std::multimap> *depend_prior) { - AnfNodePtrList vir_outputs; - std::unordered_map eqv; - auto fg_outputs = fg->output(); - if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { - auto cnode = fg_outputs->cast(); - for (size_t i = 1; i < cnode->size(); ++i) { - vir_outputs.push_back(cnode->input(i)); - } - } else { - vir_outputs.push_back(fg_outputs); - } + // GetItem should be fused with its all users. + const auto &users = mng->node_users()[getitem]; + if (std::any_of(users.begin(), users.end(), [check_include](const std::pair &user) { + return check_include(user.first) == EXCLUDE; + })) { + remove_list = DeepLinkedGraphSearch(getitem, check_include); + break; + } - if (vir_outputs.size() != outputs->size()) { - MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; + // To fix the issue of getitem-index, only support to fuse the previous node with its all users. + const auto &brothers = mng->node_users()[prev_node]; + if (std::any_of(brothers.begin(), brothers.end(), [check_include](const std::pair &user) { + return check_include(user.first) == EXCLUDE; + })) { + remove_list = DeepLinkedGraphSearch(getitem, check_include); + break; + } + } + if (!remove_list.empty()) { + for (auto node : remove_list) { + fused_op_set.erase(node); + } + changed = true; + } } - if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv, depend_prior)) { - return; + // keep the original order of fused_op. + AnfNodePtrList result; + for (auto node : fused_op) { + if (fused_op_set.count(node)) { + result.push_back(node); + } } + return result; +} - AnfNodePtr fg_new_output; - if (outputs->size() > 1) { - std::vector output_args; - output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - (void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args), - [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); - // Set output for AnfGraph - fg_new_output = fg->NewCNode(output_args); - } else { - fg_new_output = eqv[(*outputs)[0]]; +std::vector FindFuseCNodes(const CNodePtr &cnode, + const std::multimap> &dep_pri) { + // Search fusable nodes according input direction. + auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1); + auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); + if (used_nodes.size() > 1) { + used_nodes = RemoveCircle(used_nodes, dep_pri); } - fg->set_output(fg_new_output, true); + used_nodes = RemoveWildGetitem(used_nodes); + TopoSortForNodeList(&used_nodes); + return used_nodes; } bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector &todos, @@ -170,14 +149,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vectorcast(); - if (node == nullptr) { - continue; - } - if (fused_ops->count(node)) { + if (node == nullptr || fused_ops->count(node)) { continue; } - bool is_basic_op = IsBasicFuseOp(node); - if (!is_basic_op || !kernel_graph->nodes().contains(node)) { + bool is_fusible_op = IsFusibleOp(node); + if (!is_fusible_op || !kernel_graph->nodes().contains(node)) { continue; } @@ -185,26 +161,12 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vectorname() + "_"; - } fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end()); - fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); + AnfNodePtr fused_new_node; + AnfNodePtrList old_outputs; + std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "fusion"); + ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs); } std::dynamic_pointer_cast(kernel_graph)->SetExecOrderByDefault(); return changed; @@ -224,6 +186,22 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph) { return FuseBasicOps(kernel_graph, todos, &fused_ops); } -bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph); } +void EliminateGetitem(const FuncGraphPtr &func_graph) { + std::shared_ptr eliminate_getitem_pass = std::make_shared(); + auto todos = TopoSort(func_graph->get_return()); + for (auto node : todos) { + if (AnfAlgo::IsGraphKernel(node)) { + eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node)); + } + } +} + +bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { + bool changed = FuseBasicOps(func_graph); + if (changed) { + EliminateGetitem(func_graph); + } + return changed; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc index d2d9098c1e..1944bfffad 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc @@ -50,53 +50,8 @@ std::vector DeepLinkedGraphSearch(const std::vector &roo } return inputs; } - -std::vector DeepUsersSearch(const std::vector &roots, const IncludeFunc &include, - const FuncGraphManagerPtr &mng) { - std::vector users; - for (auto &root : roots) { - auto tmp = DeepUsersSearch(root, include, mng); - users.insert(users.end(), tmp.begin(), tmp.end()); - } - return users; -} } // namespace -IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { - if (cur_node == node) { - return FOLLOW; - } - if (AnfAlgo::IsGraphKernel(node) || IsBasicFuseOp(node)) { - return FOLLOW; - } - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { - auto prev_node = node->cast()->input(kRealInputNodeIndexInTupleGetItem); - if (AnfAlgo::IsGraphKernel(prev_node)) { - return FOLLOW; - } - } - return EXCLUDE; -} - -IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { - if (cur_node == node) { - return FOLLOW; - } - if (AnfAlgo::IsGraphKernel(node)) { - auto cnode = node->cast(); - auto fg = GetValueNode(cnode->input(kAnfPrimitiveIndex)); - auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - MS_EXCEPTION_IF_NULL(fg_attr_val); - auto fg_attr = GetValue(fg_attr_val); - if (fg_attr == kApplyMomentumOpName) { - return FOLLOW; - } - return EXCLUDE; - } - bool is_fusable = IsBasicFuseOp(node); - return is_fusable ? FOLLOW : EXCLUDE; -} - bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &check_node, std::set *cached_unconnected_set, std::vector *circle_nodes, const std::multimap> &depend_prior) { @@ -163,9 +118,8 @@ bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &che return !circle_nodes->empty(); } -std::vector RemoveCircle(const std::vector &fused_op, - const std::multimap> &depend_prior, - bool is_backward) { +AnfNodePtrList RemoveCircle(const std::vector &fused_op, + const std::multimap> &depend_prior) { std::set cached_unconnected_set; std::set fused_op_set(fused_op.begin(), fused_op.end()); auto include = [&fused_op_set](const AnfNodePtr &node) { @@ -181,13 +135,8 @@ std::vector RemoveCircle(const std::vector &fused_op, bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior); // delete the circle node and the node which depend on the circle node in fused op if (has_circle) { - auto mng = (*iter)->func_graph()->manager(); std::vector erase_nodes; - if (is_backward) { - erase_nodes = DeepUsersSearch(circle_nodes, include, mng); - } else { - erase_nodes = DeepLinkedGraphSearch(circle_nodes, include); - } + erase_nodes = DeepLinkedGraphSearch(circle_nodes, include); for (auto erase_node : erase_nodes) { fused_op_set.erase(erase_node); } @@ -203,60 +152,6 @@ std::vector RemoveCircle(const std::vector &fused_op, return res; } -// The GetItem node should be fused with its real input and users. -// If its real input is not in the fuse_list, the GetItem should be excluded. -AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { - if (fused_op.empty()) return AnfNodePtrList(); - std::set fused_op_set(fused_op.begin(), fused_op.end()); - auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; - - auto mng = fused_op[0]->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(mng); - bool changed = true; - while (changed) { - changed = false; - AnfNodePtrList remove_list; - for (auto node : fused_op_set) { - if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) continue; - // GetItem should be fused with its real input. - auto prev_node = node->cast()->input(kRealInputNodeIndexInTupleGetItem); - if (check_include(prev_node) == EXCLUDE) { - remove_list.push_back(node); - break; - } - - // GetItem should be fused with its all users. - auto &users = mng->node_users()[node]; - bool outside_user_found = false; - for (auto iter = users.begin(); iter != users.end(); ++iter) { - if (check_include(iter->first) == EXCLUDE) { - outside_user_found = true; - break; - } - } - if (outside_user_found) { - remove_list = DeepUsersSearch(node, check_include, mng); - break; - } - } - if (!remove_list.empty()) { - for (auto node : remove_list) { - fused_op_set.erase(node); - } - changed = true; - } - } - - // keep the original order of fused_op. - AnfNodePtrList result; - for (auto node : fused_op) { - if (fused_op_set.count(node)) { - result.push_back(node); - } - } - return result; -} - void TopoSortForNodeList(std::vector *lst) { if (lst->size() < 2) { return; @@ -310,87 +205,5 @@ void TopoSortForNodeList(std::vector *lst) { lst->assign(res.begin(), res.end()); } - -std::vector FindFuseCNodes(const CNodePtr &cnode, - const std::multimap> &dep_pri) { - auto func_graph = cnode->func_graph(); - auto mng = func_graph->manager(); - // Search fusable nodes according input direction. - auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1); - auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); - std::reverse(used_nodes.begin(), used_nodes.end()); - // Search fusable nodes according output direction. - auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, std::placeholders::_1); - auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng); - - used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end()); - if (used_nodes.size() > 1) { - used_nodes = RemoveCircle(used_nodes, dep_pri); - } - used_nodes = RemoveWildGetitem(used_nodes); - TopoSortForNodeList(&used_nodes); - return used_nodes; -} - -bool FuseCompositeOps(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto mng = kernel_graph->manager(); - if (mng == nullptr) { - mng = Manage(kernel_graph, true); - kernel_graph->set_manager(mng); - } - auto todos = TopoSort(kernel_graph->get_return()); - std::reverse(todos.begin(), todos.end()); - - std::multimap> depend_prior; - InitDependPrior(todos, &depend_prior); - - bool changed = false; - for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { - auto node = *iter; - if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) { - continue; - } - - auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - if (fg_attr != nullptr) { - auto fg_name = GetValue(fg_attr); - if (graph_kernel_black_list.count(fg_name) != 0) { - continue; - } - } - - auto fuse_nodes = FindFuseCNodes(node->cast(), depend_prior); - if (fuse_nodes.size() <= 1) { - continue; - } - changed = true; - - AnfNodePtr fused_new_node; - AnfNodePtrList old_outputs; - std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, ""); - ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs); - } - return changed; -} - -void EliminateGetItem(const FuncGraphPtr &func_graph) { - std::shared_ptr eliminate_getitem_pass = std::make_shared(); - auto todos = TopoSort(func_graph->get_return()); - for (auto node : todos) { - if (AnfAlgo::IsGraphKernel(node)) { - eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node)); - } - } -} - -bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) { - auto changed = FuseCompositeOps(std::dynamic_pointer_cast(func_graph)); - if (changed) { - EliminateGetItem(func_graph); - } - return changed; -} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.h index e669a37692..bb1f31221b 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.h @@ -28,24 +28,10 @@ namespace mindspore { namespace opt { -const std::set graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", - "LambNextMV", "LambUpdateWithLR"}; - -std::vector RemoveCircle(const std::vector &fused_op, - const std::multimap> &depend_prior, - bool is_backward = true); +AnfNodePtrList RemoveCircle(const std::vector &fused_op, + const std::multimap> &depend_prior); void TopoSortForNodeList(std::vector *lst); - -bool FuseCompositeOps(const std::shared_ptr &kernel_graph); - -class CompositeOpsFusion : public Pass { - public: - CompositeOpsFusion() : Pass("composite_ops_fusion") {} - ~CompositeOpsFusion() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; -using FuseGraphKernelPassPtr = std::shared_ptr; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc new file mode 100644 index 0000000000..0605411fe1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc @@ -0,0 +1,327 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" + +#include +#include +#include +#include +#include + +#include "base/core_ops.h" +#include "ir/graph_utils.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "debug/anf_ir_dump.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" + +namespace mindspore { +namespace opt { +namespace { +inline size_t GetIndex(const AnfNodePtr &getitem_node) { + MS_EXCEPTION_IF_NULL(getitem_node); + if (!IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) { + MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem but got " << getitem_node->fullname_with_scope(); + } + return LongToSize(GetValue( + getitem_node->cast()->input(kInputNodeOutputIndexInTupleGetItem)->cast()->value())); +} + +bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list, + bool merge_repeated_getitem = false) { + MS_EXCEPTION_IF_NULL(mng); + MS_EXCEPTION_IF_NULL(getitem_list); + auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(func_graph); + auto output = func_graph->output(); + if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { + MS_LOG(EXCEPTION) << "The output should be a MakeTuple, but got " << output->fullname_with_scope(); + } + auto output_num = output->cast()->size() - 1; + getitem_list->clear(); + getitem_list->resize(output_num, nullptr); + const auto &users = mng->node_users()[node]; + bool changed = false; + AnfNodePtrList user_nodes; + std::transform(users.begin(), users.end(), std::back_inserter(user_nodes), + [](const std::pair &user) { return user.first; }); + for (const auto &getitem : user_nodes) { + MS_EXCEPTION_IF_NULL(getitem); + auto idx = GetIndex(getitem); + if (idx >= output_num) { + MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString(); + } + if (merge_repeated_getitem && (*getitem_list)[idx] != nullptr) { + mng->Replace(getitem, (*getitem_list)[idx]); + changed = true; + } else { + (*getitem_list)[idx] = getitem; + } + } + return changed; +} + +AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) { + auto todos = TopoSort(func_graph->get_return()); + AnfNodePtrList result; + std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) { + return AnfAlgo::IsGraphKernel(node) && + IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(node)->output(), prim::kPrimMakeTuple); + }); + return result; +} + +/* Merge the get_item nodes that have same index. + * %1 = call @graph_kernel(p1, p2) + * %2 = tuple_getitem(%1, 0) + * %3 = tuple_getitem(%1, 0) + * %4 = tuple_getitem(%1, 1) + * %5 = user_x(%2) + * %6 = user_y(%3) + * %7 = user_z(%4) + * ---> + * %1 = call @graph_kernel(p1, p2) + * %2 = tuple_getitem(%1, 0) + * %3 = tuple_getitem(%1, 1) + * %4 = user_x(%2) + * %5 = user_y(%2) + * %6 = user_z(%3) + */ +class MergeRepeatedGetitem : public Pass { + public: + bool Run(const FuncGraphPtr &func_graph) { + auto mng = func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto todos = FindGraphKernelsWithMultiOutput(func_graph); + bool changed = false; + for (auto node : todos) { + AnfNodePtrList getitem_list; + changed = GetGraphKernelGetitemList(mng, node, &getitem_list, true) || changed; + } + return changed; + } +}; + +/* Merge the get_item nodes that have same index. + * subgraph graph_kernel(%para1, %para2) + * %1 = TensorAdd(%para1, %para2) + * %2 = Neg(%1) + * %3 = make_tuple(%1, %2) + * return (%3) + * %1 = call @graph_kernel(%p1, %p2) + * %2 = tuple_getitem(%1, 0) + * %3 = tuple_getitem(%1, 1) + * %4 = ControlDepend(%0, %2) + * %5 = other_user(%3) + * ---> + * subgraph graph_kernel(%para1, %para2) + * %1 = TensorAdd(%para1, %para2) + * %2 = Neg(%1) + * %3 = make_tuple(%1, %2) + * return (%3) + * %1 = call @graph_kernel(%p1, %p2) + * %3 = tuple_getitem(%1, 1) + * %4 = ControlDepend(%0, %3) + * %5 = other_user(%3) + * + * Then the output 0 can be eliminate in the later pass. + */ +class EliminateGetitemForControlDepend : public Pass { + public: + bool Run(const FuncGraphPtr &func_graph) { + auto todos = FindGraphKernelsWithMultiOutput(func_graph); + auto mng = func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + bool changed = false; + for (const auto &node : todos) { + getitems_.clear(); + GetGraphKernelGetitemList(mng, node, &getitems_, false); + if (getitems_.empty()) continue; + indexes_.clear(); + GetIndexesToControlDepend(mng); + FilterRedundantOutputs(node); + if (indexes_.empty()) continue; + size_t index = GetFinalIndex(node); + changed = ReplaceGetitems(mng, index) || changed; + } + return changed; + } + + private: + AnfNodePtrList getitems_; // Users of GraphKernel node with multiple outputs. + std::vector indexes_; // Indexes of MakeTuple to be eliminated. + + bool ReplaceGetitems(const FuncGraphManagerPtr &mng, size_t index) { + MS_EXCEPTION_IF_NULL(getitems_[index]); + bool changed = false; + for (auto i : indexes_) { + if (i != index) { + MS_EXCEPTION_IF_NULL(getitems_[i]); + mng->Replace(getitems_[i], getitems_[index]); + changed = true; + } + } + return changed; + } + + // Find the redundant output index. + // the real output should have multiple users. + void FilterRedundantOutputs(const AnfNodePtr &node) { + auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + auto &users = mng->node_users(); + auto maketuple = func_graph->output()->cast(); + MS_EXCEPTION_IF_NULL(maketuple); + std::vector result; + for (auto i : indexes_) { + auto real_output = maketuple->input(i); + if (users[real_output].size() > 1) { + result.push_back(i); + } + } + indexes_ = std::move(result); + } + + // Get the nodes that only have ControlDepend users. + void GetIndexesToControlDepend(const FuncGraphManagerPtr &mng) { + for (size_t i = 0; i < getitems_.size(); ++i) { + const AnfNodePtr &getitem = getitems_[i]; + if (getitem == nullptr) { + continue; + } + const auto &getitem_user = mng->node_users()[getitem]; + if (std::all_of(getitem_user.begin(), getitem_user.end(), [](const std::pair &user) { + return IsPrimitiveCNode(user.first, prim::kPrimControlDepend); + })) { + indexes_.push_back(i); + } + } + } + + size_t GetFinalIndex(const AnfNodePtr &node) { + auto is_redundant_index = [this](size_t i) { + return std::find(indexes_.begin(), indexes_.end(), i) != indexes_.end(); + }; + for (size_t i = 0; i < getitems_.size(); ++i) { + if (getitems_[i] != nullptr && !is_redundant_index(i)) { + return i; + } + } + return indexes_[0]; + } +}; +} // namespace + +// Remove the output without user or with virtual user (like ControlDepend) +bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) { + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + + bool changed = std::make_shared()->Run(func_graph); + changed = std::make_shared()->Run(func_graph) || changed; + changed = Process(func_graph) || changed; + return changed; +} + +void EliminateRedundantOutput::UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset) { + if (offset == 0) return; + MS_EXCEPTION_IF_NULL(getitem); + int64_t index = SizeToLong(GetIndex(getitem)); + if (offset > index) { + MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString(); + } + index -= offset; + auto idx_node = NewValueNode(MakeValue(index)); + auto abstract = std::make_shared(std::make_shared(index)); + idx_node->set_abstract(abstract); + idx_node->set_kernel_info(std::make_shared()); + getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node); +} + +AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) { + auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(func_graph); + auto old_maketuple = func_graph->output()->cast(); + MS_EXCEPTION_IF_NULL(old_maketuple); + AnfNodePtrList new_maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; + AbstractBasePtrList abstract_list; + int64_t offset = 0; + for (size_t i = 0; i < getitems.size(); ++i) { + if (getitems[i] == nullptr) { + offset++; + } else { + new_maketuple_inputs.push_back(old_maketuple->input(i + 1)); + abstract_list.push_back(old_maketuple->input(i + 1)->abstract()); + UpdateGetitemIndex(getitems[i]->cast(), offset); + } + } + if (offset == 0) return nullptr; + if (new_maketuple_inputs.size() == 1) { + MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty"; + } + if (new_maketuple_inputs.size() == 2) { + func_graph->set_output(new_maketuple_inputs.back()); + } else { + auto make_tuple = func_graph->NewCNode(new_maketuple_inputs); + make_tuple->set_abstract(std::make_shared(abstract_list)); + make_tuple->set_kernel_info(std::make_shared()); + func_graph->set_output(make_tuple); + } + + auto old_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(old_cnode); + AnfNodePtrList inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end()); + AnfNodePtrList outputs; + kernel::GetFuncGraphOutputNodes(func_graph, &outputs); + auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs, outputs); + SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs, AnfAlgo::GetProcessor(node)); + return graph_kernel_node; +} + +bool EliminateRedundantOutput::Process(const FuncGraphPtr &func_graph) { + auto mng = func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto todos = FindGraphKernelsWithMultiOutput(func_graph); + bool changed = false; + for (auto node : todos) { + AnfNodePtrList getitems; + GetGraphKernelGetitemList(mng, node, &getitems, false); + auto new_node = ReplaceMakeTuple(node, getitems); + if (new_node != nullptr) { + if (!IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(new_node)->output(), prim::kPrimMakeTuple)) { + // only one output, remove the getitem. + auto i = std::find_if(getitems.begin(), getitems.end(), [](const AnfNodePtr &node) { return node != nullptr; }); + if (i != getitems.end()) { + mng->Replace(*i, new_node); + } + } else { + mng->Replace(node, new_node); + } + changed = true; + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.h new file mode 100644 index 0000000000..81cab0c1f3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class EliminateRedundantOutput : public Pass { + public: + EliminateRedundantOutput() : Pass("eliminate_redundant_output") {} + ~EliminateRedundantOutput() override = default; + bool Run(const FuncGraphPtr &func_graph) override; + + private: + bool Process(const FuncGraphPtr &func_graph); + void UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset); + AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index c365943dec..7bd937ec63 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -531,7 +531,7 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f } std::tuple FuseNodesToSubGraph(const std::vector &fuse_nodes, - const std::shared_ptr &kernel_graph, + const FuncGraphPtr &kernel_graph, const std::string &postfix) { auto mng = kernel_graph->manager(); if (mng == nullptr) { @@ -861,23 +861,6 @@ void InitDependPrior(const std::vector &todos, } } -void UpdateControlDependNode(std::multimap> *depend_prior, - const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) { - for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) { - if (iter->second.second == control_depend_node) { - iter = depend_prior->erase(iter); - continue; - } - ++iter; - } - - std::multimap> new_depend_prior; - InitDependPrior(std::vector{new_control_depend}, &new_depend_prior); - for (auto item : new_depend_prior) { - depend_prior->insert(item); - } -} - void ReplaceNewFuseCNodeForDependPrior(std::multimap> *depend_prior, const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) { std::multimap> new_fuse_cnode_dep_pri; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index 30974acb95..d04d4c5aae 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -65,8 +65,8 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphP void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); std::tuple FuseNodesToSubGraph(const std::vector &fuse_nodes, - const std::shared_ptr &kernel_graph, - const std::string &postfix); + const FuncGraphPtr &kernel_graph, + const std::string &postfix = ""); bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc); bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, std::map *address_node_map); @@ -79,8 +79,6 @@ bool IsBasicFuseOp(const AnfNodePtr &node); void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); void InitDependPrior(const std::vector &todos, std::multimap> *depend_prior); -void UpdateControlDependNode(std::multimap> *depend_prior, - const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend); void ReplaceNewFuseCNodeForDependPrior(std::multimap> *depend_prior, const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 35ad2ca1e2..563b0eee69 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -42,7 +42,7 @@ #include "debug/data_dump/dump_json_parser.h" #include "debug/tensor_load.h" #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" -#include "backend/optimizer/graph_kernel/composite_ops_fusion.h" +#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" #include "backend/optimizer/graph_kernel/tensor_promotion.h" #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" #include "backend/optimizer/graph_kernel/graph_kernel_expander.h" @@ -822,7 +822,7 @@ void AscendSession::GraphKernelOptimize(const std::shared_ptr &kern auto pm = std::make_shared("graph_kernel_pm"); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index d70a570d4f..599345dcf8 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -38,7 +38,7 @@ #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h" #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" -#include "backend/optimizer/graph_kernel/composite_ops_fusion.h" +#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" #include "backend/optimizer/graph_kernel/tensor_promotion.h" #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" #include "backend/optimizer/graph_kernel/graph_kernel_expander.h" @@ -171,7 +171,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr &kernel_ pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared(black_list)); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared(black_list));