From: @hwhewei Reviewed-by: @zh_qh, @ginfung Signed-off-by: @zh_qh pull/14580/MERGE
| @@ -132,7 +132,7 @@ def Depend(value, expr): | |||||
| return value | return value | ||||
| def UpdateState(monad, expr): | |||||
| def UpdateState(monad, *exprs): | |||||
| """Implement `UpdateState`.""" | """Implement `UpdateState`.""" | ||||
| return monad | return monad | ||||
| @@ -90,7 +90,7 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & | |||||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | MS_EXCEPTION_IF_NULL(node_with_index.first); | ||||
| auto real_input = node_with_index.first; | auto real_input = node_with_index.first; | ||||
| if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) { | if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) { | ||||
| input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); | |||||
| input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | MS_EXCEPTION_IF_NULL(input_node); | ||||
| AnfAlgo::SetNodeInput(node, input_node, index); | AnfAlgo::SetNodeInput(node, input_node, index); | ||||
| } | } | ||||
| @@ -120,10 +120,16 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An | |||||
| return node; | return node; | ||||
| } | } | ||||
| AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const KernelSelectPtr &kernel_select) { | |||||
| AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, | |||||
| const AnfNodePtr &node, const KernelSelectPtr &kernel_select) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_node); | |||||
| for (auto &update_state : update_states) { | |||||
| manager->SetEdge(update_state.first, update_state.second, node); | |||||
| } | |||||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | ||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | ||||
| size_t out_num = AnfAlgo::GetOutputTensorNum(node); | size_t out_num = AnfAlgo::GetOutputTensorNum(node); | ||||
| @@ -282,7 +288,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr & | |||||
| return cast; | return cast; | ||||
| } | } | ||||
| AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node, | |||||
| const KernelSelectPtr &kernel_select) { | const KernelSelectPtr &kernel_select) { | ||||
| size_t outputs_num = AnfAlgo::GetOutputTensorNum(node); | size_t outputs_num = AnfAlgo::GetOutputTensorNum(node); | ||||
| if (outputs_num == 0) { | if (outputs_num == 0) { | ||||
| @@ -298,7 +304,7 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP | |||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| // Multiple output | // Multiple output | ||||
| return InsertTransOpForMultipleOutput(func_graph, node, kernel_select); | |||||
| return InsertTransOpForMultipleOutput(func_graph, orig_node, node, kernel_select); | |||||
| } | } | ||||
| AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| @@ -103,7 +103,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr & | |||||
| AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| const KernelSelectPtr &kernel_select); | const KernelSelectPtr &kernel_select); | ||||
| AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node, | |||||
| const KernelSelectPtr &kernel_select); | const KernelSelectPtr &kernel_select); | ||||
| CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | ||||
| @@ -66,6 +66,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod | |||||
| std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | ||||
| for (auto out_getitem : manager->node_users()[bnupdate]) { | for (auto out_getitem : manager->node_users()[bnupdate]) { | ||||
| MS_EXCEPTION_IF_NULL(out_getitem.first); | MS_EXCEPTION_IF_NULL(out_getitem.first); | ||||
| if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) { | |||||
| continue; | |||||
| } | |||||
| auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(out_getitem_ptr); | MS_EXCEPTION_IF_NULL(out_getitem_ptr); | ||||
| auto input2 = out_getitem_ptr->input(2); | auto input2 = out_getitem_ptr->input(2); | ||||
| @@ -43,6 +43,9 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr | |||||
| std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | ||||
| for (auto out_getitem : manager->node_users()[bnupdate]) { | for (auto out_getitem : manager->node_users()[bnupdate]) { | ||||
| MS_EXCEPTION_IF_NULL(out_getitem.first); | MS_EXCEPTION_IF_NULL(out_getitem.first); | ||||
| if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) { | |||||
| continue; | |||||
| } | |||||
| auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(out_getitem_ptr); | MS_EXCEPTION_IF_NULL(out_getitem_ptr); | ||||
| auto input2 = out_getitem_ptr->input(2); | auto input2 = out_getitem_ptr->input(2); | ||||
| @@ -297,9 +297,11 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, | |||||
| } else { | } else { | ||||
| int64_t prev_idx = 0; | int64_t prev_idx = 0; | ||||
| std::vector<AnfNodePtr> tuple_getitem_nodes; | std::vector<AnfNodePtr> tuple_getitem_nodes; | ||||
| std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), | |||||
| std::back_inserter(tuple_getitem_nodes), | |||||
| [](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; }); | |||||
| for (auto &user : manager->node_users()[node]) { | |||||
| if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimTupleGetItem)) { | |||||
| tuple_getitem_nodes.emplace_back(user.first); | |||||
| } | |||||
| } | |||||
| std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); | std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); | ||||
| for (auto &getitem : tuple_getitem_nodes) { | for (auto &getitem : tuple_getitem_nodes) { | ||||
| MS_EXCEPTION_IF_NULL(getitem); | MS_EXCEPTION_IF_NULL(getitem); | ||||
| @@ -163,7 +163,20 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get | |||||
| return func_graph->NewCNode(depend_nodes); | return func_graph->NewCNode(depend_nodes); | ||||
| } | } | ||||
| CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput( | CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput( | ||||
| const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const { | |||||
| const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const { | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto cnode = orig_cnode; | |||||
| auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode); | |||||
| if (!update_states.empty()) { | |||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| cnode = kernel_graph->NewCNode(orig_cnode); | |||||
| cnode->set_inputs(orig_cnode->inputs()); | |||||
| for (auto &update_state : update_states) { | |||||
| manager->SetEdge(update_state.first, update_state.second, cnode); | |||||
| } | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| auto ref_infos = op_info->ref_infos(); | auto ref_infos = op_info->ref_infos(); | ||||
| std::vector<AnfNodePtr> make_tuple_inputs; | std::vector<AnfNodePtr> make_tuple_inputs; | ||||
| @@ -30,9 +30,16 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||||
| AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, | |||||
| const CNodePtr &cnode) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode); | |||||
| for (auto &update_state : update_states) { | |||||
| manager->SetEdge(update_state.first, update_state.second, cnode); | |||||
| } | |||||
| std::vector<AnfNodePtr> make_tuple_inputs; | std::vector<AnfNodePtr> make_tuple_inputs; | ||||
| AbstractBasePtrList abstract_list; | AbstractBasePtrList abstract_list; | ||||
| make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); | make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); | ||||
| @@ -69,9 +76,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo | |||||
| MS_EXCEPTION_IF_NULL(make_tuple); | MS_EXCEPTION_IF_NULL(make_tuple); | ||||
| make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | ||||
| return make_tuple; | return make_tuple; | ||||
| } // namespace | |||||
| } | |||||
| AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||||
| AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const CNodePtr &cnode) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { | if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { | ||||
| @@ -99,7 +106,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c | |||||
| return replace_node; | return replace_node; | ||||
| } | } | ||||
| // Multiple output | // Multiple output | ||||
| return InsertCastForMultipleOutput(func_graph, cnode); | |||||
| return InsertCastForMultipleOutput(func_graph, orig_cnode, cnode); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -124,7 +131,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo | |||||
| kernel_graph->ReplaceInternalOutput(node, new_node); | kernel_graph->ReplaceInternalOutput(node, new_node); | ||||
| } | } | ||||
| // process output | // process output | ||||
| return InsertCastForOutput(func_graph, new_node); | |||||
| return InsertCastForOutput(func_graph, cnode, new_node); | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -43,7 +43,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An | |||||
| if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { | if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { | ||||
| kernel_graph->ReplaceInternalOutput(node, new_node); | kernel_graph->ReplaceInternalOutput(node, new_node); | ||||
| } | } | ||||
| return InsertTransOpForOutput(func_graph, new_node, kernel_select_); | |||||
| return InsertTransOpForOutput(func_graph, node, new_node, kernel_select_); | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "backend/optimizer/pass/convert_const_scalar_to_tensor.h" | #include "backend/optimizer/pass/convert_const_scalar_to_tensor.h" | ||||
| #include "backend/optimizer/pass/convert_attr_to_unify_mindir.h" | #include "backend/optimizer/pass/convert_attr_to_unify_mindir.h" | ||||
| #include "backend/optimizer/pass/add_training_attr.h" | #include "backend/optimizer/pass/add_training_attr.h" | ||||
| #include "backend/optimizer/pass/optimize_updatestate.h" | |||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| @@ -58,5 +59,24 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||||
| DumpIR(file_name, kernel_graph); | DumpIR(file_name, kernel_graph); | ||||
| } | } | ||||
| } | } | ||||
| void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| // Run optimizer passes. | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| auto pm = std::make_shared<PassManager>("final_opt"); | |||||
| pm->AddPass(std::make_shared<OptimizeUpdateState>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| // Dump IR if save_graphs is set. | |||||
| auto context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| const bool save_graphs = context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||||
| if (save_graphs) { | |||||
| std::string filename = "hwopt_common_final_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(filename, kernel_graph); | |||||
| } | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,6 +20,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -401,11 +401,9 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con | |||||
| } | } | ||||
| auto output_info_list = iter->second; | auto output_info_list = iter->second; | ||||
| for (const auto &output_info : output_info_list) { | for (const auto &output_info : output_info_list) { | ||||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | |||||
| output_info.second == kDependAttachNodeIndex) { | |||||
| continue; | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) { | |||||
| auto cnode_name = AnfAlgo::GetCNodeName(output_info.first); | |||||
| if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) || | |||||
| (cnode_name == prim::kPrimUpdateState->name())) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| output_node_list->push_back(output_info); | output_node_list->push_back(output_info); | ||||
| @@ -426,12 +424,13 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu | |||||
| } | } | ||||
| auto output_info_list = iter->second; | auto output_info_list = iter->second; | ||||
| for (const auto &output_info : output_info_list) { | for (const auto &output_info : output_info_list) { | ||||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | |||||
| output_info.second == kDependAttachNodeIndex) { | |||||
| auto cnode_name = AnfAlgo::GetCNodeName(output_info.first); | |||||
| if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) || | |||||
| (cnode_name == prim::kPrimUpdateState->name())) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| size_t used_output_index; | size_t used_output_index; | ||||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) { | |||||
| if (cnode_name == prim::kPrimTupleGetItem->name()) { | |||||
| used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first)); | used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first)); | ||||
| } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) { | } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) { | ||||
| used_output_index = output_index; | used_output_index = output_index; | ||||
| @@ -906,12 +905,13 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto manager = graph->manager(); | auto manager = graph->manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| // find BatchNorm's output which is a Depend | |||||
| // Find BatchNorm's output which is a Depend or UpdateState. | |||||
| for (const auto &node_index : manager->node_users()[old_node]) { | for (const auto &node_index : manager->node_users()[old_node]) { | ||||
| AnfNodePtr output = node_index.first; | AnfNodePtr output = node_index.first; | ||||
| size_t index = IntToSize(node_index.second); | size_t index = IntToSize(node_index.second); | ||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { | |||||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) || | |||||
| AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) { | |||||
| auto depend = output->cast<CNodePtr>(); | auto depend = output->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(depend); | MS_EXCEPTION_IF_NULL(depend); | ||||
| depend->set_input(index, new_node); | depend->set_input(index, new_node); | ||||
| @@ -66,13 +66,14 @@ bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr | |||||
| auto output_num = output->cast<CNodePtr>()->size() - 1; | auto output_num = output->cast<CNodePtr>()->size() - 1; | ||||
| getitem_list->clear(); | getitem_list->clear(); | ||||
| getitem_list->resize(output_num, nullptr); | getitem_list->resize(output_num, nullptr); | ||||
| const auto &users = mng->node_users()[node]; | |||||
| auto users = mng->node_users()[node]; | |||||
| bool changed = false; | bool changed = false; | ||||
| AnfNodePtrList user_nodes; | |||||
| std::transform(users.begin(), users.end(), std::back_inserter(user_nodes), | |||||
| [](const std::pair<AnfNodePtr, int> &user) { return user.first; }); | |||||
| for (const auto &getitem : user_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(getitem); | |||||
| for (const auto &user : users) { | |||||
| if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { | |||||
| // Sometimes, the user of MakeTuple is not a TupleGetItem, but an UpdateState. | |||||
| continue; | |||||
| } | |||||
| auto &getitem = user.first; | |||||
| auto idx = GetIndex(getitem); | auto idx = GetIndex(getitem); | ||||
| if (idx >= output_num) { | if (idx >= output_num) { | ||||
| MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString(); | MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString(); | ||||
| @@ -35,19 +35,17 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno | |||||
| const std::vector<AnfNodePtr> &new_depend_inputs) { | const std::vector<AnfNodePtr> &new_depend_inputs) { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | |||||
| CNodePtr new_depend = nullptr; | |||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||||
| if (kernel_graph == nullptr) { | if (kernel_graph == nullptr) { | ||||
| new_depend = func_graph->NewCNode(new_depend_inputs); | |||||
| auto new_depend = func_graph->NewCNode(new_depend_inputs); | |||||
| MS_EXCEPTION_IF_NULL(new_depend); | MS_EXCEPTION_IF_NULL(new_depend); | ||||
| new_depend->set_abstract(cnode->abstract()); | new_depend->set_abstract(cnode->abstract()); | ||||
| new_depend->set_scope(cnode->scope()); | new_depend->set_scope(cnode->scope()); | ||||
| } else { | |||||
| new_depend = kernel_graph->NewCNode(cnode); | |||||
| MS_EXCEPTION_IF_NULL(new_depend); | |||||
| new_depend->set_inputs(new_depend_inputs); | |||||
| return new_depend; | |||||
| } | } | ||||
| func_graph->manager()->Replace(cnode, new_depend); | |||||
| auto new_depend = kernel_graph->NewCNode(cnode); | |||||
| MS_EXCEPTION_IF_NULL(new_depend); | |||||
| new_depend->set_inputs(new_depend_inputs); | |||||
| return new_depend; | return new_depend; | ||||
| } | } | ||||
| @@ -77,9 +75,9 @@ AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, con | |||||
| auto replace_node = eliminate_node->input(kSingleInputIndex); | auto replace_node = eliminate_node->input(kSingleInputIndex); | ||||
| std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs(); | std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs(); | ||||
| new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node; | new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node; | ||||
| auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs); | |||||
| auto new_node = new_cnode->cast<AnfNodePtr>(); | |||||
| return new_node; | |||||
| auto new_depend = CreateNewDependNode(func_graph, cnode, new_depend_inputs); | |||||
| func_graph->manager()->Replace(cnode, new_depend); | |||||
| return new_depend; | |||||
| } | } | ||||
| AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | ||||
| @@ -157,55 +155,53 @@ const BaseRef OptimizeDependence::DefinePattern() const { | |||||
| return VectorRef({X, Xs}); | return VectorRef({X, Xs}); | ||||
| } | } | ||||
| std::pair<AnfNodePtr, size_t> SearchTransDataAndCast(const AnfNodePtr &node, bool is_first_node) { | |||||
| if (node == nullptr || !node->isa<CNode>()) { | |||||
| return std::pair<AnfNodePtr, size_t>(nullptr, 0); | |||||
| } | |||||
| // get real input of depend and update state. | |||||
| size_t replace_input_index = 0; | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { | |||||
| replace_input_index = is_first_node ? kDependAttachNodeIndex : kRealInputIndexInDepend; | |||||
| } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) { | |||||
| replace_input_index = is_first_node ? kUpdateStateStateInput : kUpdateStateRealInput; | |||||
| } else { | |||||
| return std::pair<AnfNodePtr, size_t>(nullptr, 0); | |||||
| } | |||||
| // check whether real input is cast or trans data | |||||
| auto real_input = node->cast<CNodePtr>()->input(replace_input_index); | |||||
| if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimCast) || | |||||
| AnfAlgo::CheckPrimitiveType(real_input, prim::KPrimTransData) || | |||||
| AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimMakeTuple)) { | |||||
| return std::pair<AnfNodePtr, size_t>(node, replace_input_index); | |||||
| } | |||||
| return SearchTransDataAndCast(real_input, false); | |||||
| std::vector<size_t> SearchTransDataAndCast(const CNodePtr &cnode) { | |||||
| // Search Depend and UpdateState only. | |||||
| if (!cnode->IsApply(prim::kPrimDepend) && !cnode->IsApply(prim::kPrimUpdateState)) { | |||||
| return {}; | |||||
| } | |||||
| // Find inputs that are Cast, TransData or MakeTuple. | |||||
| std::vector<size_t> result; | |||||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||||
| auto &input = cnode->input(i); | |||||
| if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) || | |||||
| AnfAlgo::CheckPrimitiveType(input, prim::KPrimTransData) || | |||||
| AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) { | |||||
| result.emplace_back(i); | |||||
| } | |||||
| } | |||||
| return result; | |||||
| } | } | ||||
| const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| const EquivPtr &) const { | const EquivPtr &) const { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>()) { | |||||
| auto cnode = dyn_cast<CNode>(node); | |||||
| if (cnode == nullptr) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // Get the cnode with repalce input index | |||||
| auto cnode_with_input_index = SearchTransDataAndCast(node, true); | |||||
| if (cnode_with_input_index.first == nullptr) { | |||||
| // Search inputs to be replaced. | |||||
| auto candidate_inputs = SearchTransDataAndCast(cnode); | |||||
| if (candidate_inputs.empty()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| size_t replace_index = cnode_with_input_index.second; | |||||
| auto depend_cnode = cnode_with_input_index.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(depend_cnode); | |||||
| // Get new node which will act as new input of depend or UpdateState. | |||||
| std::vector<AnfNodePtr> new_depend_inputs = depend_cnode->inputs(); | |||||
| auto replace_node = GetConvertNode(func_graph, depend_cnode, replace_index); | |||||
| if (replace_node == nullptr) { | |||||
| return nullptr; | |||||
| // Get new nodes which will act as new inputs of Depend or UpdateState. | |||||
| std::vector<AnfNodePtr> new_inputs = cnode->inputs(); | |||||
| bool inputs_changed = false; | |||||
| for (auto index : candidate_inputs) { | |||||
| auto replace_node = GetConvertNode(func_graph, cnode, index); | |||||
| if (replace_node != nullptr) { | |||||
| new_inputs[index] = replace_node; | |||||
| inputs_changed = true; | |||||
| } | |||||
| } | } | ||||
| new_depend_inputs[replace_index] = replace_node; | |||||
| auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs); | |||||
| if (new_depend == nullptr) { | |||||
| if (!inputs_changed) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // Create a new Depend or UpdateState node to replace the old one if inputs changed. | |||||
| auto new_depend = CreateNewDependNode(func_graph, cnode, new_inputs); | |||||
| func_graph->manager()->Replace(cnode, new_depend); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/pass/optimize_updatestate.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "base/core_ops.h" | |||||
| #include "utils/utils.h" | |||||
| #include "backend/session/kernel_graph.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| constexpr size_t kInputIndex = 1; | |||||
| constexpr size_t kAttachIndex = 2; | |||||
| constexpr size_t kAdditionalAttachIndex = 3; | |||||
| const BaseRef OptimizeUpdateState::DefinePattern() const { | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| return VectorRef({prim::kPrimUpdateState, Xs}); | |||||
| } | |||||
| const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| auto update_state = dyn_cast<CNode>(node); | |||||
| MS_EXCEPTION_IF_NULL(update_state); | |||||
| if (update_state->size() <= kAdditionalAttachIndex) { | |||||
| // Skip UpdateState nodes with no additional attaches. | |||||
| return nullptr; | |||||
| } | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto &node_users = manager->node_users(); | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| new_inputs.emplace_back(update_state->input(0)); | |||||
| new_inputs.emplace_back(update_state->input(kInputIndex)); | |||||
| new_inputs.emplace_back(update_state->input(kAttachIndex)); | |||||
| for (size_t i = kAdditionalAttachIndex; i < update_state->size(); ++i) { | |||||
| auto &attach = update_state->input(i); | |||||
| auto &users = node_users[attach]; | |||||
| if ((users.size() == 1) && (users.front().first == update_state)) { | |||||
| // If the only user of attach is the UpdateState node, drop the attach node. | |||||
| continue; | |||||
| } | |||||
| new_inputs.emplace_back(attach); | |||||
| } | |||||
| if (new_inputs.size() == update_state->size()) { | |||||
| // Attaches not changed. | |||||
| return nullptr; | |||||
| } | |||||
| // Attaches changed, make a new UpdateState. | |||||
| auto new_update_state = func_graph->NewCNode(new_inputs); | |||||
| new_update_state->set_abstract(update_state->abstract()); | |||||
| new_update_state->set_scope(update_state->scope()); | |||||
| return new_update_state; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_ | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| // Backend pattern pass registered under the name "optimize_updatestate". | |||||
| // NOTE(review): the matched pattern and the rewrite are defined in the implementation file - confirm details there. | |||||
| class OptimizeUpdateState : public PatternProcessPass { | |||||
| public: | |||||
| // Forwards the fixed pass name and the multigraph flag (true unless overridden) to PatternProcessPass. | |||||
| explicit OptimizeUpdateState(bool multigraph = true) : PatternProcessPass("optimize_updatestate", multigraph) {} | |||||
| ~OptimizeUpdateState() override = default; | |||||
| // Handles one matched node of the given graph; the returned node is consumed by the pass framework. | |||||
| const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const override; | |||||
| // Supplies the pattern that selects the nodes handed to Process(). | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_ | |||||
| @@ -1931,5 +1931,15 @@ void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_ | |||||
| {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()}); | {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()}); | ||||
| root_graph->set_output(make_tuple); | root_graph->set_output(make_tuple); | ||||
| } | } | ||||
| // Collects the users of `node` that are UpdateState primitives and returns the matching user entries. | |||||
| // Raises an exception if `manager` is null; returns an empty set when `node` has no recorded users. | |||||
| AnfNodeIndexSet AnfRuntimeAlgorithm::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| AnfNodeIndexSet update_states; | |||||
| // Look the node up instead of indexing, so a node without users does not get an empty entry added to the map. | |||||
| auto &node_users = manager->node_users(); | |||||
| auto iter = node_users.find(node); | |||||
| if (iter == node_users.end()) { | |||||
| return update_states; | |||||
| } | |||||
| for (auto &user : iter->second) { | |||||
| if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) { | |||||
| update_states.insert(user); | |||||
| } | |||||
| } | |||||
| return update_states; | |||||
| } | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -267,6 +267,7 @@ class AnfRuntimeAlgorithm { | |||||
| static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, | static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, | ||||
| std::set<AnfNodePtr> *visited); | std::set<AnfNodePtr> *visited); | ||||
| static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph); | static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph); | ||||
| static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node); | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | using AnfAlgo = session::AnfRuntimeAlgorithm; | ||||
| @@ -936,6 +936,7 @@ void AscendSession::InitRuntimeResource() { | |||||
| void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const { | void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const { | ||||
| MS_LOG(INFO) << "HardwareOptimize start!"; | MS_LOG(INFO) << "HardwareOptimize start!"; | ||||
| opt::AscendBackendOptimization(kernel_graph); | opt::AscendBackendOptimization(kernel_graph); | ||||
| FinalOptimize(kernel_graph); | |||||
| GraphKernelOptimize(kernel_graph); | GraphKernelOptimize(kernel_graph); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| kernel_graph->SetExecOrderByDefault(); | kernel_graph->SetExecOrderByDefault(); | ||||
| @@ -104,6 +104,7 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||||
| SetKernelInfo(graph.get()); | SetKernelInfo(graph.get()); | ||||
| MS_LOG(INFO) << "Set kernel info end"; | MS_LOG(INFO) << "Set kernel info end"; | ||||
| Optimize(graph); | Optimize(graph); | ||||
| FinalOptimize(graph); | |||||
| MS_LOG(INFO) << "Build kernel"; | MS_LOG(INFO) << "Build kernel"; | ||||
| BuildKernel(graph.get()); | BuildKernel(graph.get()); | ||||
| // Remove reorder after PS feature finish adapting push/pull in auto_monad. | // Remove reorder after PS feature finish adapting push/pull in auto_monad. | ||||
| @@ -341,6 +341,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) { | |||||
| SelectKernel(graph); | SelectKernel(graph); | ||||
| // Graph optimization relevant to device data format | // Graph optimization relevant to device data format | ||||
| HardwareOptimize(graph); | HardwareOptimize(graph); | ||||
| // Run final optimization | |||||
| FinalOptimize(graph); | |||||
| // Graph kernel fusion optimization | // Graph kernel fusion optimization | ||||
| GraphKernelOptimize(graph); | GraphKernelOptimize(graph); | ||||
| // Start gpu kernel runtime | // Start gpu kernel runtime | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <set> | #include <set> | ||||
| #include <queue> | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| @@ -2343,6 +2344,12 @@ void SessionBasic::ClearAllBucket(const GraphId &graph_id) { | |||||
| } | } | ||||
| } | } | ||||
| // Calls opt::CommonFinalOptimization on the given kernel graph, logging the graph id before and after. | |||||
| // Raises an exception if `graph` is null. | |||||
| void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const { | |||||
| // The graph id is read for logging below, so reject a null graph up front instead of dereferencing it. | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id(); | |||||
| opt::CommonFinalOptimization(graph); | |||||
| MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id(); | |||||
| } | |||||
| void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { | void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { | ||||
| #ifdef ENABLE_DUMP_IR | #ifdef ENABLE_DUMP_IR | ||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| @@ -172,6 +172,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| virtual void UpdateOutputTensors(const VectorRef *outputs, | virtual void UpdateOutputTensors(const VectorRef *outputs, | ||||
| const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node); | const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node); | ||||
| virtual void UnifyMindIR(const KernelGraphPtr &graph) {} | virtual void UnifyMindIR(const KernelGraphPtr &graph) {} | ||||
| virtual void FinalOptimize(const KernelGraphPtr &graph) const; | |||||
| virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; } | virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; } | ||||
| virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | ||||
| virtual void BuildGraphImpl(GraphId) {} | virtual void BuildGraphImpl(GraphId) {} | ||||
| @@ -32,6 +32,7 @@ | |||||
| #include "pipeline/jit/parse/parse_base.h" | #include "pipeline/jit/parse/parse_base.h" | ||||
| #include "pipeline/jit/parse/data_converter.h" | #include "pipeline/jit/parse/data_converter.h" | ||||
| #include "pipeline/jit/static_analysis/auto_monad.h" | #include "pipeline/jit/static_analysis/auto_monad.h" | ||||
| #include "pipeline/jit/static_analysis/order_enforce.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "pipeline/jit/static_analysis/static_analysis.h" | #include "pipeline/jit/static_analysis/static_analysis.h" | ||||
| #include "pipeline/jit/static_analysis/program_specialize.h" | #include "pipeline/jit/static_analysis/program_specialize.h" | ||||
| @@ -343,6 +344,18 @@ bool AutoMonadAction(const ResourcePtr &res) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| // Pipeline action that runs pipeline::OrderEnforce on the function graph held by the resource. | |||||
| bool OrderEnforceAction(const ResourcePtr &res) { | |||||
| // Both the manager and the graph must be present before the pass can run. | |||||
| const bool manager_missing = (res->manager() == nullptr); | |||||
| if (manager_missing) { | |||||
| MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null"; | |||||
| } | |||||
| const auto graph = res->func_graph(); | |||||
| if (graph == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Order-Enforce error, graph is null"; | |||||
| } | |||||
| pipeline::OrderEnforce(graph); | |||||
| // Reaching this point means neither check above raised. | |||||
| return true; | |||||
| } | |||||
| bool InferenceOptPrepareAction(const ResourcePtr &res) { | bool InferenceOptPrepareAction(const ResourcePtr &res) { | ||||
| if (res->manager() == nullptr) { | if (res->manager() == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; | MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; | ||||
| @@ -752,6 +765,7 @@ std::vector<ActionItem> GePipeline() { | |||||
| // Add opt-stage python pass stub | // Add opt-stage python pass stub | ||||
| actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub)); | actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub)); | ||||
| actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | ||||
| actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); | |||||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | actions.emplace_back(std::make_pair("validate", ValidateAction)); | ||||
| return actions; | return actions; | ||||
| } | } | ||||
| @@ -765,6 +779,8 @@ std::vector<ActionItem> VmPipeline() { | |||||
| // Add opt-stage python pass stub | // Add opt-stage python pass stub | ||||
| actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub)); | actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub)); | ||||
| actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); | |||||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | actions.emplace_back(std::make_pair("validate", ValidateAction)); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| if (ps::PSContext::instance()->is_worker()) { | if (ps::PSContext::instance()->is_worker()) { | ||||
| @@ -784,6 +800,7 @@ std::vector<ActionItem> VmPipeline() { | |||||
| std::vector<ActionItem> PServerPipeline() { | std::vector<ActionItem> PServerPipeline() { | ||||
| auto actions = CommonPipeline(); | auto actions = CommonPipeline(); | ||||
| actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | ||||
| actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); | |||||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | actions.emplace_back(std::make_pair("validate", ValidateAction)); | ||||
| actions.emplace_back(std::make_pair("pserver", StartPSServerAction)); | actions.emplace_back(std::make_pair("pserver", StartPSServerAction)); | ||||
| return actions; | return actions; | ||||
| @@ -0,0 +1,258 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pipeline/jit/static_analysis/order_enforce.h" | |||||
| #include <algorithm> | |||||
| #include <map> | |||||
| #include <queue> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include <unordered_set> | |||||
| #include <utility> | |||||
| #include "base/core_ops.h" | |||||
| namespace mindspore::pipeline { | |||||
| namespace { | |||||
| // Walks the graph in topo sort order and, for each UpdateState that carries a U monad and whose | |||||
| // users take Ref inputs, adds the users of the attached Load node(s) as extra inputs of that UpdateState. | |||||
| class OrderEnforcer { | |||||
| public: | |||||
| // Raises an exception if the graph or its manager is null. | |||||
| explicit OrderEnforcer(const FuncGraphPtr &func_graph) : func_graph_(func_graph) { | |||||
| // Validate the graph first: the manager can only be fetched from a non-null graph. | |||||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||||
| manager_ = func_graph_->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager_); | |||||
| } | |||||
| ~OrderEnforcer() = default; | |||||
| // Records the topo sort order of all nodes, then handles them one by one in that order. | |||||
| void Run() { | |||||
| auto nodes = MakeTopoSortMap(); | |||||
| for (auto &node : nodes) { | |||||
| HandleNode(node); | |||||
| } | |||||
| } | |||||
| private: | |||||
| // Topo sorts the graph from its return node and remembers each node's position. | |||||
| AnfNodePtrList MakeTopoSortMap() { | |||||
| auto nodes = TopoSort(func_graph_->get_return()); | |||||
| for (size_t i = 0; i < nodes.size(); ++i) { | |||||
| topo_sort_map_.emplace(nodes[i], i); | |||||
| } | |||||
| return nodes; | |||||
| } | |||||
| // Returns the recorded topo sort position of a node without modifying the map. | |||||
| // Nodes that were not recorded are reported as position 0, the value a default map entry would hold. | |||||
| size_t TopoOrder(const AnfNodePtr &node) const { | |||||
| auto iter = topo_sort_map_.find(node); | |||||
| return iter == topo_sort_map_.end() ? 0 : iter->second; | |||||
| } | |||||
| // Dispatches an UpdateState node to the Load or MakeTuple handler; all other nodes are ignored. | |||||
| void HandleNode(const AnfNodePtr &node) { | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) { | |||||
| // Skip nodes other than UpdateState. | |||||
| return; | |||||
| } | |||||
| auto update_state = node->cast<CNodePtr>(); | |||||
| if (!HasAbstractUMonad(update_state->input(1))) { | |||||
| // Skip UpdateStates for IO. | |||||
| return; | |||||
| } | |||||
| auto updated_refs = FindUpdatedRefs(update_state); | |||||
| if (updated_refs.empty()) { | |||||
| // Skip UpdateStates that do not have updated refs. | |||||
| return; | |||||
| } | |||||
| auto &attach = update_state->input(2); | |||||
| if (IsPrimitiveCNode(attach, prim::kPrimLoad)) { | |||||
| // Handle UpdateState with Load. | |||||
| EnforceOrderForLoad(update_state, attach->cast<CNodePtr>(), updated_refs); | |||||
| } else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { | |||||
| // Handle UpdateState with MakeTuple. | |||||
| EnforceOrderForTuple(update_state, attach->cast<CNodePtr>(), updated_refs); | |||||
| } | |||||
| } | |||||
| // Collects the Ref-typed inputs of the UpdateState's users, ignoring Load, Depend and UpdateState users. | |||||
| std::unordered_set<AnfNodePtr> FindUpdatedRefs(const CNodePtr &update_state) { | |||||
| std::unordered_set<AnfNodePtr> updated_refs; | |||||
| auto &users = manager_->node_users()[update_state]; | |||||
| for (auto &user : users) { | |||||
| auto cnode = dyn_cast<CNode>(user.first); | |||||
| if (cnode == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) || | |||||
| cnode->IsApply(prim::kPrimUpdateState)) { | |||||
| continue; | |||||
| } | |||||
| for (auto &input : cnode->inputs()) { | |||||
| if (IsRef(input)) { | |||||
| updated_refs.insert(input); | |||||
| } | |||||
| } | |||||
| } | |||||
| return updated_refs; | |||||
| } | |||||
| // True when the node has an abstract and that abstract is an AbstractRef. | |||||
| bool IsRef(const AnfNodePtr &node) { | |||||
| auto &abs = node->abstract(); | |||||
| return abs != nullptr && abs->isa<abstract::AbstractRef>(); | |||||
| } | |||||
| // Handles an UpdateState whose attach input is a single Load node. | |||||
| void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load, | |||||
| const std::unordered_set<AnfNodePtr> &refs) { | |||||
| if (refs.find(load->input(1)) == refs.end()) { | |||||
| // Skip if loaded parameter is not updated. | |||||
| return; | |||||
| } | |||||
| // Find load users, ignore processed nodes. | |||||
| auto load_users = FindLoadUsers(load, update_state); | |||||
| // Find load users that do not depend on the UpdateState, | |||||
| // and then let the UpdateState depend on them. | |||||
| AddInputEdges(update_state, load_users); | |||||
| } | |||||
| // Handles an UpdateState whose attach input is a MakeTuple of Load nodes. | |||||
| void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, | |||||
| const std::unordered_set<AnfNodePtr> &refs) { | |||||
| // The UpdateState should be the only one user of the make_tuple. | |||||
| // for performance, we only check the number of output edges. | |||||
| if (manager_->node_users()[make_tuple].size() != 1) { | |||||
| return; | |||||
| } | |||||
| // Find load users from the tuple of Load nodes. | |||||
| std::unordered_set<AnfNodePtr> all_load_users; | |||||
| auto &inputs = make_tuple->inputs(); | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||||
| auto &input = inputs.at(i); | |||||
| if (!IsPrimitiveCNode(input, prim::kPrimLoad)) { | |||||
| // Skip non-Load nodes. | |||||
| continue; | |||||
| } | |||||
| auto load = input->cast<CNodePtr>(); | |||||
| if (refs.find(load->input(1)) == refs.end()) { | |||||
| // Skip if loaded parameter is not updated. | |||||
| continue; | |||||
| } | |||||
| auto load_users = FindLoadUsers(load, make_tuple); | |||||
| all_load_users.insert(load_users.begin(), load_users.end()); | |||||
| } | |||||
| // Find load users that do not depend on the UpdateState, | |||||
| // and then let the UpdateState depend on them. | |||||
| AddInputEdges(update_state, all_load_users); | |||||
| } | |||||
| // Add load users as input edges of the update_state node. | |||||
| // Users that already depend on the UpdateState are left alone, since such an edge would form a cycle. | |||||
| void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) { | |||||
| auto sorted_load_users = SortLoadUsers(load_users); | |||||
| for (auto &load_user : sorted_load_users) { | |||||
| if (!IsDependOn(load_user, update_state)) { | |||||
| processed_nodes_.insert(load_user); | |||||
| manager_->AddEdge(update_state, load_user); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Sort load users by their topo sort order. | |||||
| std::vector<AnfNodePtr> SortLoadUsers(const std::unordered_set<AnfNodePtr> &load_users) { | |||||
| std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()}; | |||||
| std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); }); | |||||
| return vec; | |||||
| } | |||||
| // Check if the load user node depend on the given UpdateState node. | |||||
| // Searches the user's inputs breadth-first, pruning nodes whose topo position is before the UpdateState. | |||||
| bool IsDependOn(const AnfNodePtr &load_user, const AnfNodePtr &update_state) { | |||||
| size_t update_state_order = TopoOrder(update_state); | |||||
| if (TopoOrder(load_user) < update_state_order) { | |||||
| return false; | |||||
| } | |||||
| auto user_cnode = dyn_cast<CNode>(load_user); | |||||
| if (user_cnode == nullptr) { | |||||
| return false; | |||||
| } | |||||
| size_t seen = NewSeenGeneration(); | |||||
| std::queue<CNodePtr> q; | |||||
| user_cnode->seen_ = seen; | |||||
| q.push(user_cnode); | |||||
| while (!q.empty()) { | |||||
| auto cnode = q.front(); | |||||
| q.pop(); | |||||
| for (auto &input : cnode->inputs()) { | |||||
| if (input == update_state) { | |||||
| // Dependency found. | |||||
| return true; | |||||
| } | |||||
| if (input->seen_ == seen) { | |||||
| // Skip visited nodes. | |||||
| continue; | |||||
| } | |||||
| if (TopoOrder(input) < update_state_order) { | |||||
| // Skip input nodes that are before the UpdateState node. | |||||
| continue; | |||||
| } | |||||
| auto input_cnode = dyn_cast<CNode>(input); | |||||
| if (input_cnode != nullptr) { | |||||
| input_cnode->seen_ = seen; | |||||
| q.push(input_cnode); | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| // True when node1 has a smaller topo sort position than node2. | |||||
| bool IsBefore(const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||||
| return TopoOrder(node1) < TopoOrder(node2); | |||||
| } | |||||
| // Find Load users as the candidate nodes to enforce order of execution. | |||||
| std::unordered_set<AnfNodePtr> FindLoadUsers(const CNodePtr &load, const AnfNodePtr &exclude) { | |||||
| auto &node_users = manager_->node_users(); | |||||
| auto iter = node_users.find(load); | |||||
| if (iter == node_users.end()) { | |||||
| return {}; | |||||
| } | |||||
| std::unordered_set<AnfNodePtr> load_users; | |||||
| auto &users = iter->second; | |||||
| for (auto &user : users) { | |||||
| auto &user_node = user.first; | |||||
| if (user_node == exclude) { | |||||
| // Skip excluded node. | |||||
| continue; | |||||
| } | |||||
| if (processed_nodes_.find(user_node) != processed_nodes_.end()) { | |||||
| // Skip processed nodes. | |||||
| continue; | |||||
| } | |||||
| auto cnode = dyn_cast<CNode>(user_node); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto &inputs = cnode->inputs(); | |||||
| const bool has_u_input = | |||||
| std::any_of(inputs.begin() + 1, inputs.end(), [](const AnfNodePtr &input) { return HasAbstractUMonad(input); }); | |||||
| if (has_u_input) { | |||||
| // Skip nodes with memory side effects, which use u input. | |||||
| continue; | |||||
| } | |||||
| load_users.insert(cnode); | |||||
| } | |||||
| return load_users; | |||||
| } | |||||
| private: | |||||
| // Held by value so the enforcer never refers to a caller-owned pointer object that may go away. | |||||
| FuncGraphPtr func_graph_; | |||||
| FuncGraphManagerPtr manager_; | |||||
| // Topo sort position of every node reached from the graph's return node. | |||||
| std::unordered_map<AnfNodePtr, size_t> topo_sort_map_; | |||||
| // Load users that were already added as inputs of some UpdateState. | |||||
| std::unordered_set<AnfNodePtr> processed_nodes_; | |||||
| }; | |||||
| } // namespace | |||||
| // | |||||
| // Enforce order of execution for Load users node. | |||||
| // | |||||
| void OrderEnforce(const FuncGraphPtr &func_graph) { | |||||
| // Build a one-shot enforcer for this graph and run it right away. | |||||
| OrderEnforcer(func_graph).Run(); | |||||
| } | |||||
| } // namespace mindspore::pipeline | |||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_ | |||||
| #define MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_ | |||||
| #include "ir/func_graph.h" | |||||
| namespace mindspore::pipeline { | |||||
| // Enforce order of execution of the given graph: users of a Load become inputs of the related UpdateState. | |||||
| void OrderEnforce(const FuncGraphPtr &func_graph); | |||||
| } // namespace mindspore::pipeline | |||||
| #endif // MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_ | |||||
| @@ -1456,7 +1456,10 @@ def test_while_forward(): | |||||
| assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001) | assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001) | ||||
| @pytest.mark.skip(reason="not supported yet") | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_multi_add_assign(): | def test_multi_add_assign(): | ||||
| class Net(Cell): | class Net(Cell): | ||||
| def __init__(self, i1): | def __init__(self, i1): | ||||
| @@ -1493,7 +1496,10 @@ def test_multi_add_assign(): | |||||
| np.testing.assert_array_equal(outputs, expects) | np.testing.assert_array_equal(outputs, expects) | ||||
| @pytest.mark.skip(reason="not supported yet") | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_multi_abs_add_assign(): | def test_multi_abs_add_assign(): | ||||
| class Net(Cell): | class Net(Cell): | ||||
| def __init__(self, para): | def __init__(self, para): | ||||