From: @dayschan Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_doutags/v1.1.0
| @@ -61,11 +61,11 @@ def expand_fusedadam(expand_info): | |||
| next_para = graph_builder.emit('Sub', [param, update_with_lr]) | |||
| param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) | |||
| m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True}) | |||
| v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True}) | |||
| param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True}) | |||
| param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True}) | |||
| # set graph output. | |||
| graph_scope.set_output(param_result, m_result, v_result) | |||
| graph_scope.set_output(param_result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| @@ -66,11 +66,11 @@ def expand_fusedadamweightdecay(expand_info): | |||
| next_para = graph_builder.emit('Sub', [param, update_with_lr]) | |||
| para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) | |||
| m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True}) | |||
| v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True}) | |||
| para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True}) | |||
| para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True}) | |||
| # set graph output. | |||
| graph_scope.set_output(para_result, m_result, v_result) | |||
| graph_scope.set_output(para_result) | |||
| graph = graph_builder.get()[0] | |||
| return graph | |||
| @@ -117,8 +117,6 @@ | |||
| #include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h" | |||
| #include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/optimizer/graph_kernel/composite_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | |||
| #include "utils/config_manager.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "debug/dump_proto.h" | |||
| @@ -31,132 +31,111 @@ | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "backend/optimizer/graph_kernel/composite_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool IsFusibleOp(const AnfNodePtr &node) { | |||
| #if ENABLE_D | |||
| const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", | |||
| "LambNextMV", "LambUpdateWithLR"}; | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| if (fg_attr != nullptr) { | |||
| return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0; | |||
| } | |||
| } | |||
| #endif | |||
| return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node); | |||
| } | |||
| IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { | |||
| if (cur_node == node) { | |||
| return FOLLOW; | |||
| } | |||
| if (!IsPrimitiveCNode(node)) { | |||
| return EXCLUDE; | |||
| } | |||
| bool is_fusable = IsBasicFuseOp(node); | |||
| return is_fusable ? FOLLOW : EXCLUDE; | |||
| } | |||
| std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) { | |||
| // Search fusable nodes according input direction. | |||
| auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1); | |||
| auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); | |||
| if (used_nodes.size() > 1) { | |||
| used_nodes = RemoveCircle(used_nodes, dep_pri, false); | |||
| if (IsFusibleOp(node)) { | |||
| return FOLLOW; | |||
| } | |||
| TopoSortForNodeList(&used_nodes); | |||
| return used_nodes; | |||
| } | |||
| void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &users, | |||
| std::vector<CNodePtr> *control_depend_nodes, std::vector<size_t> *control_depend_use_index, | |||
| bool *is_only_control_depend_use, AnfNodePtr *use_out) { | |||
| for (auto &user : users) { | |||
| auto use_node = user.first; | |||
| if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { | |||
| *is_only_control_depend_use = false; | |||
| continue; | |||
| } | |||
| if (outputs_set.count(use_node) != 0) { | |||
| *use_out = use_node; | |||
| } | |||
| if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { | |||
| control_depend_nodes->push_back(use_node->cast<CNodePtr>()); | |||
| control_depend_use_index->push_back(user.second); | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||
| auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); | |||
| if (AnfAlgo::IsGraphKernel(prev_node)) { | |||
| return FOLLOW; | |||
| } | |||
| } | |||
| return EXCLUDE; | |||
| } | |||
| bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv, | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) { | |||
| AnfNodeSet outputs_set; | |||
| for (auto out : *outputs) { | |||
| outputs_set.insert(out); | |||
| } | |||
| bool has_erase_outs = false; | |||
| int index = -1; | |||
| for (auto it = outputs->begin(); it != outputs->end();) { | |||
| index++; | |||
| auto out = *it; | |||
| (*eqv)[out] = vir_outputs[IntToSize(index)]; | |||
| auto users = mng->node_users()[out]; | |||
| bool is_only_control_depend_use = true; | |||
| std::vector<size_t> control_depend_use_index; | |||
| std::vector<CNodePtr> control_depend_nodes; | |||
| AnfNodePtr use_out = nullptr; | |||
| SearchForDependNode(outputs_set, users, &control_depend_nodes, &control_depend_use_index, | |||
| &is_only_control_depend_use, &use_out); | |||
| if (is_only_control_depend_use && !control_depend_nodes.empty()) { | |||
| MS_EXCEPTION_IF_NULL(use_out); | |||
| it = outputs->erase(it); | |||
| for (size_t i = 0; i < control_depend_nodes.size(); ++i) { | |||
| auto control_depend_node = control_depend_nodes[i]; | |||
| std::vector<AnfNodePtr> new_control_depend_inputs; | |||
| for (size_t j = 0; j < control_depend_node->size(); ++j) { | |||
| if (j == control_depend_use_index[i]) { | |||
| new_control_depend_inputs.push_back(use_out); | |||
| } else { | |||
| new_control_depend_inputs.push_back(control_depend_node->input(j)); | |||
| } | |||
| } | |||
| auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs); | |||
| mng->Replace(control_depend_node, new_control_depend); | |||
| has_erase_outs = true; | |||
| UpdateControlDependNode(depend_prior, control_depend_node, new_control_depend); | |||
| // The GetItem node should be fused with its real input and users. | |||
| // If its real input is not in the fuse_list, the GetItem should be excluded. | |||
| AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { | |||
| if (fused_op.empty()) return AnfNodePtrList(); | |||
| std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); | |||
| auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; | |||
| auto mng = fused_op[0]->func_graph()->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| bool changed = true; | |||
| while (changed) { | |||
| changed = false; | |||
| AnfNodePtrList remove_list; | |||
| for (auto getitem : fused_op_set) { | |||
| if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue; | |||
| // GetItem should be fused with its real input. | |||
| auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); | |||
| if (check_include(prev_node) == EXCLUDE) { | |||
| remove_list.push_back(getitem); | |||
| break; | |||
| } | |||
| } else { | |||
| it++; | |||
| } | |||
| } | |||
| return has_erase_outs; | |||
| } | |||
| void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng, | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) { | |||
| AnfNodePtrList vir_outputs; | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> eqv; | |||
| auto fg_outputs = fg->output(); | |||
| if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { | |||
| auto cnode = fg_outputs->cast<CNodePtr>(); | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| vir_outputs.push_back(cnode->input(i)); | |||
| } | |||
| } else { | |||
| vir_outputs.push_back(fg_outputs); | |||
| } | |||
| // GetItem should be fused with its all users. | |||
| const auto &users = mng->node_users()[getitem]; | |||
| if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) { | |||
| return check_include(user.first) == EXCLUDE; | |||
| })) { | |||
| remove_list = DeepLinkedGraphSearch(getitem, check_include); | |||
| break; | |||
| } | |||
| if (vir_outputs.size() != outputs->size()) { | |||
| MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; | |||
| // To fix the issue of getitem-index, only support to fuse the previous node with its all users. | |||
| const auto &brothers = mng->node_users()[prev_node]; | |||
| if (std::any_of(brothers.begin(), brothers.end(), [check_include](const std::pair<AnfNodePtr, int> &user) { | |||
| return check_include(user.first) == EXCLUDE; | |||
| })) { | |||
| remove_list = DeepLinkedGraphSearch(getitem, check_include); | |||
| break; | |||
| } | |||
| } | |||
| if (!remove_list.empty()) { | |||
| for (auto node : remove_list) { | |||
| fused_op_set.erase(node); | |||
| } | |||
| changed = true; | |||
| } | |||
| } | |||
| if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv, depend_prior)) { | |||
| return; | |||
| // keep the original order of fused_op. | |||
| AnfNodePtrList result; | |||
| for (auto node : fused_op) { | |||
| if (fused_op_set.count(node)) { | |||
| result.push_back(node); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| AnfNodePtr fg_new_output; | |||
| if (outputs->size() > 1) { | |||
| std::vector<AnfNodePtr> output_args; | |||
| output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| (void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args), | |||
| [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); | |||
| // Set output for AnfGraph | |||
| fg_new_output = fg->NewCNode(output_args); | |||
| } else { | |||
| fg_new_output = eqv[(*outputs)[0]]; | |||
| std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) { | |||
| // Search fusable nodes according input direction. | |||
| auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1); | |||
| auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); | |||
| if (used_nodes.size() > 1) { | |||
| used_nodes = RemoveCircle(used_nodes, dep_pri); | |||
| } | |||
| fg->set_output(fg_new_output, true); | |||
| used_nodes = RemoveWildGetitem(used_nodes); | |||
| TopoSortForNodeList(&used_nodes); | |||
| return used_nodes; | |||
| } | |||
| bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos, | |||
| @@ -170,14 +149,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||
| for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | |||
| auto node = (*iter)->cast<CNodePtr>(); | |||
| if (node == nullptr) { | |||
| continue; | |||
| } | |||
| if (fused_ops->count(node)) { | |||
| if (node == nullptr || fused_ops->count(node)) { | |||
| continue; | |||
| } | |||
| bool is_basic_op = IsBasicFuseOp(node); | |||
| if (!is_basic_op || !kernel_graph->nodes().contains(node)) { | |||
| bool is_fusible_op = IsFusibleOp(node); | |||
| if (!is_fusible_op || !kernel_graph->nodes().contains(node)) { | |||
| continue; | |||
| } | |||
| @@ -185,26 +161,12 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||
| if (fuse_nodes.size() <= 1) { | |||
| continue; | |||
| } | |||
| changed = true; | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); | |||
| RemoveControlDependOut(fg, &outputs, mng, &depend_prior); | |||
| ConvertNonscalarTensorToParameter(fg, &inputs); | |||
| auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs); | |||
| SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0])); | |||
| ReplaceNewFuseCNodeForDependPrior(&depend_prior, fuse_new_node, outputs); | |||
| ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); | |||
| // Set graph kernel attr | |||
| std::string fuse_op_name = ""; | |||
| for (auto &fuse_node : fuse_nodes) { | |||
| fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_"; | |||
| } | |||
| fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end()); | |||
| fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); | |||
| AnfNodePtr fused_new_node; | |||
| AnfNodePtrList old_outputs; | |||
| std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "fusion"); | |||
| ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs); | |||
| } | |||
| std::dynamic_pointer_cast<session::KernelGraph>(kernel_graph)->SetExecOrderByDefault(); | |||
| return changed; | |||
| @@ -224,6 +186,22 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph) { | |||
| return FuseBasicOps(kernel_graph, todos, &fused_ops); | |||
| } | |||
| bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph); } | |||
| void EliminateGetitem(const FuncGraphPtr &func_graph) { | |||
| std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>(); | |||
| auto todos = TopoSort(func_graph->get_return()); | |||
| for (auto node : todos) { | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node)); | |||
| } | |||
| } | |||
| } | |||
| bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { | |||
| bool changed = FuseBasicOps(func_graph); | |||
| if (changed) { | |||
| EliminateGetitem(func_graph); | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -50,53 +50,8 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roo | |||
| } | |||
| return inputs; | |||
| } | |||
| std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include, | |||
| const FuncGraphManagerPtr &mng) { | |||
| std::vector<AnfNodePtr> users; | |||
| for (auto &root : roots) { | |||
| auto tmp = DeepUsersSearch(root, include, mng); | |||
| users.insert(users.end(), tmp.begin(), tmp.end()); | |||
| } | |||
| return users; | |||
| } | |||
| } // namespace | |||
| IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { | |||
| if (cur_node == node) { | |||
| return FOLLOW; | |||
| } | |||
| if (AnfAlgo::IsGraphKernel(node) || IsBasicFuseOp(node)) { | |||
| return FOLLOW; | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||
| auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); | |||
| if (AnfAlgo::IsGraphKernel(prev_node)) { | |||
| return FOLLOW; | |||
| } | |||
| } | |||
| return EXCLUDE; | |||
| } | |||
| IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { | |||
| if (cur_node == node) { | |||
| return FOLLOW; | |||
| } | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kAnfPrimitiveIndex)); | |||
| auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| MS_EXCEPTION_IF_NULL(fg_attr_val); | |||
| auto fg_attr = GetValue<std::string>(fg_attr_val); | |||
| if (fg_attr == kApplyMomentumOpName) { | |||
| return FOLLOW; | |||
| } | |||
| return EXCLUDE; | |||
| } | |||
| bool is_fusable = IsBasicFuseOp(node); | |||
| return is_fusable ? FOLLOW : EXCLUDE; | |||
| } | |||
| bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node, | |||
| std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) { | |||
| @@ -163,9 +118,8 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che | |||
| return !circle_nodes->empty(); | |||
| } | |||
| std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior, | |||
| bool is_backward) { | |||
| AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) { | |||
| std::set<AnfNodePtr> cached_unconnected_set; | |||
| std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); | |||
| auto include = [&fused_op_set](const AnfNodePtr &node) { | |||
| @@ -181,13 +135,8 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, | |||
| bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior); | |||
| // delete the circle node and the node which depend on the circle node in fused op | |||
| if (has_circle) { | |||
| auto mng = (*iter)->func_graph()->manager(); | |||
| std::vector<AnfNodePtr> erase_nodes; | |||
| if (is_backward) { | |||
| erase_nodes = DeepUsersSearch(circle_nodes, include, mng); | |||
| } else { | |||
| erase_nodes = DeepLinkedGraphSearch(circle_nodes, include); | |||
| } | |||
| erase_nodes = DeepLinkedGraphSearch(circle_nodes, include); | |||
| for (auto erase_node : erase_nodes) { | |||
| fused_op_set.erase(erase_node); | |||
| } | |||
| @@ -203,60 +152,6 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, | |||
| return res; | |||
| } | |||
| // The GetItem node should be fused with its real input and users. | |||
| // If its real input is not in the fuse_list, the GetItem should be excluded. | |||
| AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { | |||
| if (fused_op.empty()) return AnfNodePtrList(); | |||
| std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); | |||
| auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; | |||
| auto mng = fused_op[0]->func_graph()->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| bool changed = true; | |||
| while (changed) { | |||
| changed = false; | |||
| AnfNodePtrList remove_list; | |||
| for (auto node : fused_op_set) { | |||
| if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) continue; | |||
| // GetItem should be fused with its real input. | |||
| auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); | |||
| if (check_include(prev_node) == EXCLUDE) { | |||
| remove_list.push_back(node); | |||
| break; | |||
| } | |||
| // GetItem should be fused with its all users. | |||
| auto &users = mng->node_users()[node]; | |||
| bool outside_user_found = false; | |||
| for (auto iter = users.begin(); iter != users.end(); ++iter) { | |||
| if (check_include(iter->first) == EXCLUDE) { | |||
| outside_user_found = true; | |||
| break; | |||
| } | |||
| } | |||
| if (outside_user_found) { | |||
| remove_list = DeepUsersSearch(node, check_include, mng); | |||
| break; | |||
| } | |||
| } | |||
| if (!remove_list.empty()) { | |||
| for (auto node : remove_list) { | |||
| fused_op_set.erase(node); | |||
| } | |||
| changed = true; | |||
| } | |||
| } | |||
| // keep the original order of fused_op. | |||
| AnfNodePtrList result; | |||
| for (auto node : fused_op) { | |||
| if (fused_op_set.count(node)) { | |||
| result.push_back(node); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) { | |||
| if (lst->size() < 2) { | |||
| return; | |||
| @@ -310,87 +205,5 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) { | |||
| lst->assign(res.begin(), res.end()); | |||
| } | |||
| std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) { | |||
| auto func_graph = cnode->func_graph(); | |||
| auto mng = func_graph->manager(); | |||
| // Search fusable nodes according input direction. | |||
| auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1); | |||
| auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); | |||
| std::reverse(used_nodes.begin(), used_nodes.end()); | |||
| // Search fusable nodes according output direction. | |||
| auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, std::placeholders::_1); | |||
| auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng); | |||
| used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end()); | |||
| if (used_nodes.size() > 1) { | |||
| used_nodes = RemoveCircle(used_nodes, dep_pri); | |||
| } | |||
| used_nodes = RemoveWildGetitem(used_nodes); | |||
| TopoSortForNodeList(&used_nodes); | |||
| return used_nodes; | |||
| } | |||
| bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto mng = kernel_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(kernel_graph, true); | |||
| kernel_graph->set_manager(mng); | |||
| } | |||
| auto todos = TopoSort(kernel_graph->get_return()); | |||
| std::reverse(todos.begin(), todos.end()); | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior; | |||
| InitDependPrior(todos, &depend_prior); | |||
| bool changed = false; | |||
| for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | |||
| auto node = *iter; | |||
| if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) { | |||
| continue; | |||
| } | |||
| auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| if (fg_attr != nullptr) { | |||
| auto fg_name = GetValue<std::string>(fg_attr); | |||
| if (graph_kernel_black_list.count(fg_name) != 0) { | |||
| continue; | |||
| } | |||
| } | |||
| auto fuse_nodes = FindFuseCNodes(node->cast<CNodePtr>(), depend_prior); | |||
| if (fuse_nodes.size() <= 1) { | |||
| continue; | |||
| } | |||
| changed = true; | |||
| AnfNodePtr fused_new_node; | |||
| AnfNodePtrList old_outputs; | |||
| std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, ""); | |||
| ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs); | |||
| } | |||
| return changed; | |||
| } | |||
| void EliminateGetItem(const FuncGraphPtr &func_graph) { | |||
| std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>(); | |||
| auto todos = TopoSort(func_graph->get_return()); | |||
| for (auto node : todos) { | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node)); | |||
| } | |||
| } | |||
| } | |||
| bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) { | |||
| auto changed = FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph)); | |||
| if (changed) { | |||
| EliminateGetItem(func_graph); | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -28,24 +28,10 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", | |||
| "LambNextMV", "LambUpdateWithLR"}; | |||
| std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior, | |||
| bool is_backward = true); | |||
| AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior); | |||
| void TopoSortForNodeList(std::vector<AnfNodePtr> *lst); | |||
| bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| class CompositeOpsFusion : public Pass { | |||
| public: | |||
| CompositeOpsFusion() : Pass("composite_ops_fusion") {} | |||
| ~CompositeOpsFusion() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| }; | |||
| using FuseGraphKernelPassPtr = std::shared_ptr<CompositeOpsFusion>; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_ | |||
| @@ -0,0 +1,327 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "base/core_ops.h" | |||
| #include "ir/graph_utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| inline size_t GetIndex(const AnfNodePtr &getitem_node) { | |||
| MS_EXCEPTION_IF_NULL(getitem_node); | |||
| if (!IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) { | |||
| MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem but got " << getitem_node->fullname_with_scope(); | |||
| } | |||
| return LongToSize(GetValue<int64_t>( | |||
| getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value())); | |||
| } | |||
| bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list, | |||
| bool merge_repeated_getitem = false) { | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| MS_EXCEPTION_IF_NULL(getitem_list); | |||
| auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto output = func_graph->output(); | |||
| if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { | |||
| MS_LOG(EXCEPTION) << "The output should be a MakeTuple, but got " << output->fullname_with_scope(); | |||
| } | |||
| auto output_num = output->cast<CNodePtr>()->size() - 1; | |||
| getitem_list->clear(); | |||
| getitem_list->resize(output_num, nullptr); | |||
| const auto &users = mng->node_users()[node]; | |||
| bool changed = false; | |||
| AnfNodePtrList user_nodes; | |||
| std::transform(users.begin(), users.end(), std::back_inserter(user_nodes), | |||
| [](const std::pair<AnfNodePtr, int> &user) { return user.first; }); | |||
| for (const auto &getitem : user_nodes) { | |||
| MS_EXCEPTION_IF_NULL(getitem); | |||
| auto idx = GetIndex(getitem); | |||
| if (idx >= output_num) { | |||
| MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString(); | |||
| } | |||
| if (merge_repeated_getitem && (*getitem_list)[idx] != nullptr) { | |||
| mng->Replace(getitem, (*getitem_list)[idx]); | |||
| changed = true; | |||
| } else { | |||
| (*getitem_list)[idx] = getitem; | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) { | |||
| auto todos = TopoSort(func_graph->get_return()); | |||
| AnfNodePtrList result; | |||
| std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) { | |||
| return AnfAlgo::IsGraphKernel(node) && | |||
| IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(node)->output(), prim::kPrimMakeTuple); | |||
| }); | |||
| return result; | |||
| } | |||
| /* Merge the get_item nodes that have same index. | |||
| * %1 = call @graph_kernel(p1, p2) | |||
| * %2 = tuple_getitem(%1, 0) | |||
| * %3 = tuple_getitem(%1, 0) | |||
| * %4 = tuple_getitem(%1, 1) | |||
| * %5 = user_x(%2) | |||
| * %6 = user_y(%3) | |||
| * %7 = user_z(%4) | |||
| * ---> | |||
| * %1 = call @graph_kernel(p1, p2) | |||
| * %2 = tuple_getitem(%1, 0) | |||
| * %3 = tuple_getitem(%1, 1) | |||
| * %4 = user_x(%2) | |||
| * %5 = user_y(%2) | |||
| * %6 = user_z(%3) | |||
| */ | |||
| class MergeRepeatedGetitem : public Pass { | |||
| public: | |||
| bool Run(const FuncGraphPtr &func_graph) { | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto todos = FindGraphKernelsWithMultiOutput(func_graph); | |||
| bool changed = false; | |||
| for (auto node : todos) { | |||
| AnfNodePtrList getitem_list; | |||
| changed = GetGraphKernelGetitemList(mng, node, &getitem_list, true) || changed; | |||
| } | |||
| return changed; | |||
| } | |||
| }; | |||
| /* Merge the get_item nodes that have same index. | |||
| * subgraph graph_kernel(%para1, %para2) | |||
| * %1 = TensorAdd(%para1, %para2) | |||
| * %2 = Neg(%1) | |||
| * %3 = make_tuple(%1, %2) | |||
| * return (%3) | |||
| * %1 = call @graph_kernel(%p1, %p2) | |||
| * %2 = tuple_getitem(%1, 0) | |||
| * %3 = tuple_getitem(%1, 1) | |||
| * %4 = ControlDepend(%0, %2) | |||
| * %5 = other_user(%3) | |||
| * ---> | |||
| * subgraph graph_kernel(%para1, %para2) | |||
| * %1 = TensorAdd(%para1, %para2) | |||
| * %2 = Neg(%1) | |||
| * %3 = make_tuple(%1, %2) | |||
| * return (%3) | |||
| * %1 = call @graph_kernel(%p1, %p2) | |||
| * %3 = tuple_getitem(%1, 1) | |||
| * %4 = ControlDepend(%0, %3) | |||
| * %5 = other_user(%3) | |||
| * | |||
| * Then the output 0 can be eliminate in the later pass. | |||
| */ | |||
| class EliminateGetitemForControlDepend : public Pass { | |||
| public: | |||
| bool Run(const FuncGraphPtr &func_graph) { | |||
| auto todos = FindGraphKernelsWithMultiOutput(func_graph); | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| bool changed = false; | |||
| for (const auto &node : todos) { | |||
| getitems_.clear(); | |||
| GetGraphKernelGetitemList(mng, node, &getitems_, false); | |||
| if (getitems_.empty()) continue; | |||
| indexes_.clear(); | |||
| GetIndexesToControlDepend(mng); | |||
| FilterRedundantOutputs(node); | |||
| if (indexes_.empty()) continue; | |||
| size_t index = GetFinalIndex(node); | |||
| changed = ReplaceGetitems(mng, index) || changed; | |||
| } | |||
| return changed; | |||
| } | |||
| private: | |||
| AnfNodePtrList getitems_; // Users of GraphKernel node with multiple outputs. | |||
| std::vector<size_t> indexes_; // Indexes of MakeTuple to be eliminated. | |||
| bool ReplaceGetitems(const FuncGraphManagerPtr &mng, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(getitems_[index]); | |||
| bool changed = false; | |||
| for (auto i : indexes_) { | |||
| if (i != index) { | |||
| MS_EXCEPTION_IF_NULL(getitems_[i]); | |||
| mng->Replace(getitems_[i], getitems_[index]); | |||
| changed = true; | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| // Find the redundant output index. | |||
| // the real output should have multiple users. | |||
| void FilterRedundantOutputs(const AnfNodePtr &node) { | |||
| auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| func_graph->set_manager(mng); | |||
| } | |||
| auto &users = mng->node_users(); | |||
| auto maketuple = func_graph->output()->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(maketuple); | |||
| std::vector<size_t> result; | |||
| for (auto i : indexes_) { | |||
| auto real_output = maketuple->input(i); | |||
| if (users[real_output].size() > 1) { | |||
| result.push_back(i); | |||
| } | |||
| } | |||
| indexes_ = std::move(result); | |||
| } | |||
| // Get the nodes that only have ControlDepend users. | |||
| void GetIndexesToControlDepend(const FuncGraphManagerPtr &mng) { | |||
| for (size_t i = 0; i < getitems_.size(); ++i) { | |||
| const AnfNodePtr &getitem = getitems_[i]; | |||
| if (getitem == nullptr) { | |||
| continue; | |||
| } | |||
| const auto &getitem_user = mng->node_users()[getitem]; | |||
| if (std::all_of(getitem_user.begin(), getitem_user.end(), [](const std::pair<AnfNodePtr, int> &user) { | |||
| return IsPrimitiveCNode(user.first, prim::kPrimControlDepend); | |||
| })) { | |||
| indexes_.push_back(i); | |||
| } | |||
| } | |||
| } | |||
| size_t GetFinalIndex(const AnfNodePtr &node) { | |||
| auto is_redundant_index = [this](size_t i) { | |||
| return std::find(indexes_.begin(), indexes_.end(), i) != indexes_.end(); | |||
| }; | |||
| for (size_t i = 0; i < getitems_.size(); ++i) { | |||
| if (getitems_[i] != nullptr && !is_redundant_index(i)) { | |||
| return i; | |||
| } | |||
| } | |||
| return indexes_[0]; | |||
| } | |||
| }; | |||
| } // namespace | |||
| // Remove the output without user or with virtual user (like ControlDepend) | |||
| bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) { | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| func_graph->set_manager(mng); | |||
| } | |||
| bool changed = std::make_shared<MergeRepeatedGetitem>()->Run(func_graph); | |||
| changed = std::make_shared<EliminateGetitemForControlDepend>()->Run(func_graph) || changed; | |||
| changed = Process(func_graph) || changed; | |||
| return changed; | |||
| } | |||
| void EliminateRedundantOutput::UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset) { | |||
| if (offset == 0) return; | |||
| MS_EXCEPTION_IF_NULL(getitem); | |||
| int64_t index = SizeToLong(GetIndex(getitem)); | |||
| if (offset > index) { | |||
| MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString(); | |||
| } | |||
| index -= offset; | |||
| auto idx_node = NewValueNode(MakeValue<int64_t>(index)); | |||
| auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index)); | |||
| idx_node->set_abstract(abstract); | |||
| idx_node->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node); | |||
| } | |||
// Rebuild the subgraph's output MakeTuple keeping only the outputs that still
// have a getitem user (non-null entries of `getitems`), shifting the surviving
// getitems' indexes down accordingly, then re-fuse the modified subgraph.
// Returns the new graph-kernel call node, or nullptr when nothing was removed.
AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
  auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto old_maketuple = func_graph->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_maketuple);
  AnfNodePtrList new_maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtrList abstract_list;
  int64_t offset = 0;  // count of outputs dropped so far (left of position i)
  for (size_t i = 0; i < getitems.size(); ++i) {
    if (getitems[i] == nullptr) {
      // No getitem user: output i is redundant and is left out of the new tuple.
      offset++;
    } else {
      // Output i is MakeTuple input i + 1 (input 0 is the MakeTuple primitive).
      new_maketuple_inputs.push_back(old_maketuple->input(i + 1));
      abstract_list.push_back(old_maketuple->input(i + 1)->abstract());
      // The surviving getitem's index shrinks by the number of dropped outputs before it.
      UpdateGetitemIndex(getitems[i]->cast<CNodePtr>(), offset);
    }
  }
  if (offset == 0) return nullptr;  // nothing was removed; keep the original node
  if (new_maketuple_inputs.size() == 1) {
    MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty";
  }
  if (new_maketuple_inputs.size() == 2) {
    // Exactly one output left: return it directly rather than wrapping in a 1-tuple.
    func_graph->set_output(new_maketuple_inputs.back());
  } else {
    auto make_tuple = func_graph->NewCNode(new_maketuple_inputs);
    make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
    make_tuple->set_kernel_info(std::make_shared<device::KernelInfo>());
    func_graph->set_output(make_tuple);
  }
  // Re-fuse the modified subgraph into a fresh graph-kernel call node with the
  // same real inputs (skip input 0, the embedded func graph value node).
  auto old_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_cnode);
  AnfNodePtrList inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end());
  AnfNodePtrList outputs;
  kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
  auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs, outputs);
  SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
  return graph_kernel_node;
}
| bool EliminateRedundantOutput::Process(const FuncGraphPtr &func_graph) { | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto todos = FindGraphKernelsWithMultiOutput(func_graph); | |||
| bool changed = false; | |||
| for (auto node : todos) { | |||
| AnfNodePtrList getitems; | |||
| GetGraphKernelGetitemList(mng, node, &getitems, false); | |||
| auto new_node = ReplaceMakeTuple(node, getitems); | |||
| if (new_node != nullptr) { | |||
| if (!IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(new_node)->output(), prim::kPrimMakeTuple)) { | |||
| // only one output, remove the getitem. | |||
| auto i = std::find_if(getitems.begin(), getitems.end(), [](const AnfNodePtr &node) { return node != nullptr; }); | |||
| if (i != getitems.end()) { | |||
| mng->Replace(*i, new_node); | |||
| } | |||
| } else { | |||
| mng->Replace(node, new_node); | |||
| } | |||
| changed = true; | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
// Remove graph-kernel outputs that have no user, or only a virtual user such
// as ControlDepend, after merging duplicated tuple_getitem nodes.
class EliminateRedundantOutput : public Pass {
 public:
  EliminateRedundantOutput() : Pass("eliminate_redundant_output") {}
  ~EliminateRedundantOutput() override = default;
  // Runs the elimination on the whole graph; returns true if it changed anything.
  bool Run(const FuncGraphPtr &func_graph) override;

 private:
  // Rebuilds each multi-output graph kernel without its unused outputs.
  bool Process(const FuncGraphPtr &func_graph);
  // Shifts a getitem's index down by `offset` removed outputs before it.
  void UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset);
  // Rebuilds the subgraph's MakeTuple; returns the new fused node or nullptr.
  AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems);
};
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_ | |||
| @@ -531,7 +531,7 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f | |||
| } | |||
| std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, | |||
| const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||
| const FuncGraphPtr &kernel_graph, | |||
| const std::string &postfix) { | |||
| auto mng = kernel_graph->manager(); | |||
| if (mng == nullptr) { | |||
| @@ -861,23 +861,6 @@ void InitDependPrior(const std::vector<AnfNodePtr> &todos, | |||
| } | |||
| } | |||
// Replace the records of `control_depend_node` in the depend-prior multimap
// with freshly computed records for `new_control_depend`.
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
                             const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) {
  // Drop every entry whose stored ControlDepend node is the old one.
  for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
    if (iter->second.second == control_depend_node) {
      iter = depend_prior->erase(iter);  // erase returns the next valid iterator
      continue;
    }
    ++iter;
  }
  // Recompute the prior records of the new ControlDepend node and merge them in.
  std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_depend_prior;
  InitDependPrior(std::vector<AnfNodePtr>{new_control_depend}, &new_depend_prior);
  for (auto item : new_depend_prior) {
    depend_prior->insert(item);
  }
}
| void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | |||
| const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) { | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri; | |||
| @@ -65,8 +65,8 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphP | |||
| void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, | |||
| const AnfNodePtrList &outputs); | |||
| std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, | |||
| const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||
| const std::string &postfix); | |||
| const FuncGraphPtr &kernel_graph, | |||
| const std::string &postfix = ""); | |||
| bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc); | |||
| bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, | |||
| std::map<std::string, AnfNodePtr> *address_node_map); | |||
| @@ -79,8 +79,6 @@ bool IsBasicFuseOp(const AnfNodePtr &node); | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | |||
| void InitDependPrior(const std::vector<AnfNodePtr> &todos, | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior); | |||
| void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | |||
| const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend); | |||
| void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | |||
| const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); | |||
| @@ -42,7 +42,7 @@ | |||
| #include "debug/data_dump/dump_json_parser.h" | |||
| #include "debug/tensor_load.h" | |||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/composite_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" | |||
| #include "backend/optimizer/graph_kernel/tensor_promotion.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_expander.h" | |||
| @@ -822,7 +822,7 @@ void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kern | |||
| auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm"); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); | |||
| pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); | |||
| pm->AddPass(std::make_shared<opt::CompositeOpsFusion>()); | |||
| pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>()); | |||
| pm->AddPass(std::make_shared<opt::TensorPromotion>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelSplitter>()); | |||
| @@ -38,7 +38,7 @@ | |||
| #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h" | |||
| #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | |||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/composite_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" | |||
| #include "backend/optimizer/graph_kernel/tensor_promotion.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_expander.h" | |||
| @@ -171,7 +171,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_ | |||
| pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); | |||
| pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>()); | |||
| pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); | |||
| pm->AddPass(std::make_shared<opt::CompositeOpsFusion>()); | |||
| pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list)); | |||
| pm->AddPass(std::make_shared<opt::ArithmeticSimplify>()); | |||
| pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list)); | |||