From: @hwhewei Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qhpull/14580/MERGE
| @@ -132,7 +132,7 @@ def Depend(value, expr): | |||
| return value | |||
| def UpdateState(monad, expr): | |||
| def UpdateState(monad, *exprs): | |||
| """Implement `UpdateState`.""" | |||
| return monad | |||
| @@ -90,7 +90,7 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & | |||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||
| auto real_input = node_with_index.first; | |||
| if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) { | |||
| input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); | |||
| input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| AnfAlgo::SetNodeInput(node, input_node, index); | |||
| } | |||
| @@ -120,10 +120,16 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An | |||
| return node; | |||
| } | |||
| AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const KernelSelectPtr &kernel_select) { | |||
| AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, | |||
| const AnfNodePtr &node, const KernelSelectPtr &kernel_select) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_node); | |||
| for (auto &update_state : update_states) { | |||
| manager->SetEdge(update_state.first, update_state.second, node); | |||
| } | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| size_t out_num = AnfAlgo::GetOutputTensorNum(node); | |||
| @@ -282,7 +288,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr & | |||
| return cast; | |||
| } | |||
| AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node, | |||
| const KernelSelectPtr &kernel_select) { | |||
| size_t outputs_num = AnfAlgo::GetOutputTensorNum(node); | |||
| if (outputs_num == 0) { | |||
| @@ -298,7 +304,7 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP | |||
| return new_node; | |||
| } | |||
| // Multiple output | |||
| return InsertTransOpForMultipleOutput(func_graph, node, kernel_select); | |||
| return InsertTransOpForMultipleOutput(func_graph, orig_node, node, kernel_select); | |||
| } | |||
| AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| @@ -103,7 +103,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr & | |||
| AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const KernelSelectPtr &kernel_select); | |||
| AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node, | |||
| const KernelSelectPtr &kernel_select); | |||
| CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| @@ -66,6 +66,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod | |||
| std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | |||
| for (auto out_getitem : manager->node_users()[bnupdate]) { | |||
| MS_EXCEPTION_IF_NULL(out_getitem.first); | |||
| if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) { | |||
| continue; | |||
| } | |||
| auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(out_getitem_ptr); | |||
| auto input2 = out_getitem_ptr->input(2); | |||
| @@ -43,6 +43,9 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr | |||
| std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); | |||
| for (auto out_getitem : manager->node_users()[bnupdate]) { | |||
| MS_EXCEPTION_IF_NULL(out_getitem.first); | |||
| if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) { | |||
| continue; | |||
| } | |||
| auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(out_getitem_ptr); | |||
| auto input2 = out_getitem_ptr->input(2); | |||
| @@ -297,9 +297,11 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, | |||
| } else { | |||
| int64_t prev_idx = 0; | |||
| std::vector<AnfNodePtr> tuple_getitem_nodes; | |||
| std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), | |||
| std::back_inserter(tuple_getitem_nodes), | |||
| [](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; }); | |||
| for (auto &user : manager->node_users()[node]) { | |||
| if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimTupleGetItem)) { | |||
| tuple_getitem_nodes.emplace_back(user.first); | |||
| } | |||
| } | |||
| std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); | |||
| for (auto &getitem : tuple_getitem_nodes) { | |||
| MS_EXCEPTION_IF_NULL(getitem); | |||
| @@ -163,7 +163,20 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get | |||
| return func_graph->NewCNode(depend_nodes); | |||
| } | |||
| CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput( | |||
| const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const { | |||
| const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const { | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto cnode = orig_cnode; | |||
| auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode); | |||
| if (!update_states.empty()) { | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| cnode = kernel_graph->NewCNode(orig_cnode); | |||
| cnode->set_inputs(orig_cnode->inputs()); | |||
| for (auto &update_state : update_states) { | |||
| manager->SetEdge(update_state.first, update_state.second, cnode); | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| auto ref_infos = op_info->ref_infos(); | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| @@ -30,9 +30,16 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, | |||
| const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode); | |||
| for (auto &update_state : update_states) { | |||
| manager->SetEdge(update_state.first, update_state.second, cnode); | |||
| } | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| AbstractBasePtrList abstract_list; | |||
| make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| @@ -69,9 +76,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| return make_tuple; | |||
| } // namespace | |||
| } | |||
| AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { | |||
| @@ -99,7 +106,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c | |||
| return replace_node; | |||
| } | |||
| // Multiple output | |||
| return InsertCastForMultipleOutput(func_graph, cnode); | |||
| return InsertCastForMultipleOutput(func_graph, orig_cnode, cnode); | |||
| } | |||
| } // namespace | |||
| @@ -124,7 +131,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo | |||
| kernel_graph->ReplaceInternalOutput(node, new_node); | |||
| } | |||
| // process output | |||
| return InsertCastForOutput(func_graph, new_node); | |||
| return InsertCastForOutput(func_graph, cnode, new_node); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -43,7 +43,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An | |||
| if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { | |||
| kernel_graph->ReplaceInternalOutput(node, new_node); | |||
| } | |||
| return InsertTransOpForOutput(func_graph, new_node, kernel_select_); | |||
| return InsertTransOpForOutput(func_graph, node, new_node, kernel_select_); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -25,6 +25,7 @@ | |||
| #include "backend/optimizer/pass/convert_const_scalar_to_tensor.h" | |||
| #include "backend/optimizer/pass/convert_attr_to_unify_mindir.h" | |||
| #include "backend/optimizer/pass/add_training_attr.h" | |||
| #include "backend/optimizer/pass/optimize_updatestate.h" | |||
| #include "utils/ms_context.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| @@ -58,5 +59,24 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||
| DumpIR(file_name, kernel_graph); | |||
| } | |||
| } | |||
| void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| // Run optimizer passes. | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| auto pm = std::make_shared<PassManager>("final_opt"); | |||
| pm->AddPass(std::make_shared<OptimizeUpdateState>()); | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| // Dump IR if save_graphs is set. | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| const bool save_graphs = context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| if (save_graphs) { | |||
| std::string filename = "hwopt_common_final_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||
| DumpIR(filename, kernel_graph); | |||
| } | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -20,6 +20,7 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -401,11 +401,9 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con | |||
| } | |||
| auto output_info_list = iter->second; | |||
| for (const auto &output_info : output_info_list) { | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | |||
| output_info.second == kDependAttachNodeIndex) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) { | |||
| auto cnode_name = AnfAlgo::GetCNodeName(output_info.first); | |||
| if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) || | |||
| (cnode_name == prim::kPrimUpdateState->name())) { | |||
| continue; | |||
| } | |||
| output_node_list->push_back(output_info); | |||
| @@ -426,12 +424,13 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu | |||
| } | |||
| auto output_info_list = iter->second; | |||
| for (const auto &output_info : output_info_list) { | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | |||
| output_info.second == kDependAttachNodeIndex) { | |||
| auto cnode_name = AnfAlgo::GetCNodeName(output_info.first); | |||
| if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) || | |||
| (cnode_name == prim::kPrimUpdateState->name())) { | |||
| continue; | |||
| } | |||
| size_t used_output_index; | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) { | |||
| if (cnode_name == prim::kPrimTupleGetItem->name()) { | |||
| used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first)); | |||
| } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) { | |||
| used_output_index = output_index; | |||
| @@ -906,12 +905,13 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| // find BatchNorm's output which is a Depend | |||
| // Find BatchNorm's output which is a Depend or UpdateState. | |||
| for (const auto &node_index : manager->node_users()[old_node]) { | |||
| AnfNodePtr output = node_index.first; | |||
| size_t index = IntToSize(node_index.second); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) || | |||
| AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) { | |||
| auto depend = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(depend); | |||
| depend->set_input(index, new_node); | |||
| @@ -66,13 +66,14 @@ bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr | |||
| auto output_num = output->cast<CNodePtr>()->size() - 1; | |||
| getitem_list->clear(); | |||
| getitem_list->resize(output_num, nullptr); | |||
| const auto &users = mng->node_users()[node]; | |||
| auto users = mng->node_users()[node]; | |||
| bool changed = false; | |||
| AnfNodePtrList user_nodes; | |||
| std::transform(users.begin(), users.end(), std::back_inserter(user_nodes), | |||
| [](const std::pair<AnfNodePtr, int> &user) { return user.first; }); | |||
| for (const auto &getitem : user_nodes) { | |||
| MS_EXCEPTION_IF_NULL(getitem); | |||
| for (const auto &user : users) { | |||
| if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { | |||
| // Sometime, the user of MakeTuple is not a TupleGetItem, but a UpdateState. | |||
| continue; | |||
| } | |||
| auto &getitem = user.first; | |||
| auto idx = GetIndex(getitem); | |||
| if (idx >= output_num) { | |||
| MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString(); | |||
| @@ -35,19 +35,17 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno | |||
| const std::vector<AnfNodePtr> &new_depend_inputs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | |||
| CNodePtr new_depend = nullptr; | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| if (kernel_graph == nullptr) { | |||
| new_depend = func_graph->NewCNode(new_depend_inputs); | |||
| auto new_depend = func_graph->NewCNode(new_depend_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_depend); | |||
| new_depend->set_abstract(cnode->abstract()); | |||
| new_depend->set_scope(cnode->scope()); | |||
| } else { | |||
| new_depend = kernel_graph->NewCNode(cnode); | |||
| MS_EXCEPTION_IF_NULL(new_depend); | |||
| new_depend->set_inputs(new_depend_inputs); | |||
| return new_depend; | |||
| } | |||
| func_graph->manager()->Replace(cnode, new_depend); | |||
| auto new_depend = kernel_graph->NewCNode(cnode); | |||
| MS_EXCEPTION_IF_NULL(new_depend); | |||
| new_depend->set_inputs(new_depend_inputs); | |||
| return new_depend; | |||
| } | |||
| @@ -77,9 +75,9 @@ AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, con | |||
| auto replace_node = eliminate_node->input(kSingleInputIndex); | |||
| std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs(); | |||
| new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node; | |||
| auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs); | |||
| auto new_node = new_cnode->cast<AnfNodePtr>(); | |||
| return new_node; | |||
| auto new_depend = CreateNewDependNode(func_graph, cnode, new_depend_inputs); | |||
| func_graph->manager()->Replace(cnode, new_depend); | |||
| return new_depend; | |||
| } | |||
| AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| @@ -157,55 +155,53 @@ const BaseRef OptimizeDependence::DefinePattern() const { | |||
| return VectorRef({X, Xs}); | |||
| } | |||
| std::pair<AnfNodePtr, size_t> SearchTransDataAndCast(const AnfNodePtr &node, bool is_first_node) { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| return std::pair<AnfNodePtr, size_t>(nullptr, 0); | |||
| } | |||
| // get real input of depend and update state. | |||
| size_t replace_input_index = 0; | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { | |||
| replace_input_index = is_first_node ? kDependAttachNodeIndex : kRealInputIndexInDepend; | |||
| } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) { | |||
| replace_input_index = is_first_node ? kUpdateStateStateInput : kUpdateStateRealInput; | |||
| } else { | |||
| return std::pair<AnfNodePtr, size_t>(nullptr, 0); | |||
| } | |||
| // check whether real input is cast or trans data | |||
| auto real_input = node->cast<CNodePtr>()->input(replace_input_index); | |||
| if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimCast) || | |||
| AnfAlgo::CheckPrimitiveType(real_input, prim::KPrimTransData) || | |||
| AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimMakeTuple)) { | |||
| return std::pair<AnfNodePtr, size_t>(node, replace_input_index); | |||
| } | |||
| return SearchTransDataAndCast(real_input, false); | |||
| std::vector<size_t> SearchTransDataAndCast(const CNodePtr &cnode) { | |||
| // Search Depend and UpdateState only. | |||
| if (!cnode->IsApply(prim::kPrimDepend) && !cnode->IsApply(prim::kPrimUpdateState)) { | |||
| return {}; | |||
| } | |||
| // Find inputs which is Cast or TransData. | |||
| std::vector<size_t> result; | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| auto &input = cnode->input(i); | |||
| if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) || | |||
| AnfAlgo::CheckPrimitiveType(input, prim::KPrimTransData) || | |||
| AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) { | |||
| result.emplace_back(i); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| auto cnode = dyn_cast<CNode>(node); | |||
| if (cnode == nullptr) { | |||
| return nullptr; | |||
| } | |||
| // Get the cnode with repalce input index | |||
| auto cnode_with_input_index = SearchTransDataAndCast(node, true); | |||
| if (cnode_with_input_index.first == nullptr) { | |||
| // Search inputs to be replaced. | |||
| auto candidate_inputs = SearchTransDataAndCast(cnode); | |||
| if (candidate_inputs.empty()) { | |||
| return nullptr; | |||
| } | |||
| size_t replace_index = cnode_with_input_index.second; | |||
| auto depend_cnode = cnode_with_input_index.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(depend_cnode); | |||
| // Get new node which will act as new input of depend or UpdateState. | |||
| std::vector<AnfNodePtr> new_depend_inputs = depend_cnode->inputs(); | |||
| auto replace_node = GetConvertNode(func_graph, depend_cnode, replace_index); | |||
| if (replace_node == nullptr) { | |||
| return nullptr; | |||
| // Get new nodes which will act as new inputs of Depend or UpdateState. | |||
| std::vector<AnfNodePtr> new_inputs = cnode->inputs(); | |||
| bool inputs_changed = false; | |||
| for (auto index : candidate_inputs) { | |||
| auto replace_node = GetConvertNode(func_graph, cnode, index); | |||
| if (replace_node != nullptr) { | |||
| new_inputs[index] = replace_node; | |||
| inputs_changed = true; | |||
| } | |||
| } | |||
| new_depend_inputs[replace_index] = replace_node; | |||
| auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs); | |||
| if (new_depend == nullptr) { | |||
| if (!inputs_changed) { | |||
| return nullptr; | |||
| } | |||
| // Create a new Depend node to replace the old one if inputs changed. | |||
| auto new_depend = CreateNewDependNode(func_graph, cnode, new_inputs); | |||
| func_graph->manager()->Replace(cnode, new_depend); | |||
| return nullptr; | |||
| } | |||
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/pass/optimize_updatestate.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "base/core_ops.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| constexpr size_t kInputIndex = 1; | |||
| constexpr size_t kAttachIndex = 2; | |||
| constexpr size_t kAdditionalAttachIndex = 3; | |||
| const BaseRef OptimizeUpdateState::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimUpdateState, Xs}); | |||
| } | |||
| const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| auto update_state = dyn_cast<CNode>(node); | |||
| MS_EXCEPTION_IF_NULL(update_state); | |||
| if (update_state->size() <= kAdditionalAttachIndex) { | |||
| // Skip UpdateState nodes with no additional attaches. | |||
| return nullptr; | |||
| } | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| new_inputs.emplace_back(update_state->input(0)); | |||
| new_inputs.emplace_back(update_state->input(kInputIndex)); | |||
| new_inputs.emplace_back(update_state->input(kAttachIndex)); | |||
| for (size_t i = kAdditionalAttachIndex; i < update_state->size(); ++i) { | |||
| auto &attach = update_state->input(i); | |||
| auto &users = node_users[attach]; | |||
| if ((users.size() == 1) && (users.front().first == update_state)) { | |||
| // If the only user of attach is the UpdateState node, drop the attach node. | |||
| continue; | |||
| } | |||
| new_inputs.emplace_back(attach); | |||
| } | |||
| if (new_inputs.size() == update_state->size()) { | |||
| // Attaches not changed. | |||
| return nullptr; | |||
| } | |||
| // Attaches changed, make a new UpdateState. | |||
| auto new_update_state = func_graph->NewCNode(new_inputs); | |||
| new_update_state->set_abstract(update_state->abstract()); | |||
| new_update_state->set_scope(update_state->scope()); | |||
| return new_update_state; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class OptimizeUpdateState : public PatternProcessPass { | |||
| public: | |||
| explicit OptimizeUpdateState(bool multigraph = true) : PatternProcessPass("optimize_updatestate", multigraph) {} | |||
| ~OptimizeUpdateState() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_ | |||
| @@ -1931,5 +1931,15 @@ void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_ | |||
| {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()}); | |||
| root_graph->set_output(make_tuple); | |||
| } | |||
| AnfNodeIndexSet AnfRuntimeAlgorithm::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | |||
| AnfNodeIndexSet update_states; | |||
| for (auto &user : manager->node_users()[node]) { | |||
| if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) { | |||
| update_states.insert(user); | |||
| } | |||
| } | |||
| return update_states; | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -267,6 +267,7 @@ class AnfRuntimeAlgorithm { | |||
| static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, | |||
| std::set<AnfNodePtr> *visited); | |||
| static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph); | |||
| static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node); | |||
| }; | |||
| } // namespace session | |||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | |||
| @@ -936,6 +936,7 @@ void AscendSession::InitRuntimeResource() { | |||
| void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||
| MS_LOG(INFO) << "HardwareOptimize start!"; | |||
| opt::AscendBackendOptimization(kernel_graph); | |||
| FinalOptimize(kernel_graph); | |||
| GraphKernelOptimize(kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| @@ -104,6 +104,7 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||
| SetKernelInfo(graph.get()); | |||
| MS_LOG(INFO) << "Set kernel info end"; | |||
| Optimize(graph); | |||
| FinalOptimize(graph); | |||
| MS_LOG(INFO) << "Build kernel"; | |||
| BuildKernel(graph.get()); | |||
| // Remove reorder after PS feature finish adapting push/pull in auto_monad. | |||
| @@ -341,6 +341,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) { | |||
| SelectKernel(graph); | |||
| // Graph optimization relevant to device data format | |||
| HardwareOptimize(graph); | |||
| // Run final optimization | |||
| FinalOptimize(graph); | |||
| // Graph kernel fusion optimization | |||
| GraphKernelOptimize(graph); | |||
| // Start gpu kernel runtime | |||
| @@ -17,6 +17,7 @@ | |||
| #include <algorithm> | |||
| #include <set> | |||
| #include <queue> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| @@ -2343,6 +2344,12 @@ void SessionBasic::ClearAllBucket(const GraphId &graph_id) { | |||
| } | |||
| } | |||
| void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const { | |||
| MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id(); | |||
| opt::CommonFinalOptimization(graph); | |||
| MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id(); | |||
| } | |||
| void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| #ifdef ENABLE_DUMP_IR | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| @@ -172,6 +172,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| virtual void UpdateOutputTensors(const VectorRef *outputs, | |||
| const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node); | |||
| virtual void UnifyMindIR(const KernelGraphPtr &graph) {} | |||
| virtual void FinalOptimize(const KernelGraphPtr &graph) const; | |||
| virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; } | |||
| virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | |||
| virtual void BuildGraphImpl(GraphId) {} | |||
| @@ -32,6 +32,7 @@ | |||
| #include "pipeline/jit/parse/parse_base.h" | |||
| #include "pipeline/jit/parse/data_converter.h" | |||
| #include "pipeline/jit/static_analysis/auto_monad.h" | |||
| #include "pipeline/jit/static_analysis/order_enforce.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "pipeline/jit/static_analysis/static_analysis.h" | |||
| #include "pipeline/jit/static_analysis/program_specialize.h" | |||
| @@ -343,6 +344,18 @@ bool AutoMonadAction(const ResourcePtr &res) { | |||
| return true; | |||
| } | |||
| bool OrderEnforceAction(const ResourcePtr &res) { | |||
| if (res->manager() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null"; | |||
| } | |||
| auto func_graph = res->func_graph(); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Order-Enforce error, graph is null"; | |||
| } | |||
| pipeline::OrderEnforce(func_graph); | |||
| return true; | |||
| } | |||
| bool InferenceOptPrepareAction(const ResourcePtr &res) { | |||
| if (res->manager() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; | |||
| @@ -752,6 +765,7 @@ std::vector<ActionItem> GePipeline() { | |||
| // Add opt-stage python pass stub | |||
| actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub)); | |||
| actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | |||
| actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); | |||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| return actions; | |||
| } | |||
| @@ -765,6 +779,8 @@ std::vector<ActionItem> VmPipeline() { | |||
| // Add opt-stage python pass stub | |||
| actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub)); | |||
| actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); | |||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::PSContext::instance()->is_worker()) { | |||
| @@ -784,6 +800,7 @@ std::vector<ActionItem> VmPipeline() { | |||
| std::vector<ActionItem> PServerPipeline() { | |||
| auto actions = CommonPipeline(); | |||
| actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | |||
| actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); | |||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| actions.emplace_back(std::make_pair("pserver", StartPSServerAction)); | |||
| return actions; | |||
| @@ -0,0 +1,258 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pipeline/jit/static_analysis/order_enforce.h" | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <queue> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <utility> | |||
| #include "base/core_ops.h" | |||
| namespace mindspore::pipeline { | |||
| namespace { | |||
| class OrderEnforcer { | |||
| public: | |||
| explicit OrderEnforcer(const FuncGraphPtr &func_graph) : func_graph_(func_graph), manager_(func_graph->manager()) { | |||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| } | |||
| ~OrderEnforcer() = default; | |||
| void Run() { | |||
| auto nodes = MakeTopoSortMap(); | |||
| for (auto &node : nodes) { | |||
| HandleNode(node); | |||
| } | |||
| } | |||
| private: | |||
| AnfNodePtrList MakeTopoSortMap() { | |||
| auto nodes = TopoSort(func_graph_->get_return()); | |||
| for (size_t i = 0; i < nodes.size(); ++i) { | |||
| topo_sort_map_.emplace(nodes[i], i); | |||
| } | |||
| return nodes; | |||
| } | |||
| void HandleNode(const AnfNodePtr &node) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) { | |||
| // Skip nodes other than UpdateState. | |||
| return; | |||
| } | |||
| auto update_state = node->cast<CNodePtr>(); | |||
| if (!HasAbstractUMonad(update_state->input(1))) { | |||
| // Skip UpdateStates for IO. | |||
| return; | |||
| } | |||
| auto updated_refs = FindUpdatedRefs(update_state); | |||
| if (updated_refs.empty()) { | |||
| // Skip UpdateStates that do not have updated refs. | |||
| return; | |||
| } | |||
| auto &attach = update_state->input(2); | |||
| if (IsPrimitiveCNode(attach, prim::kPrimLoad)) { | |||
| // Handle UpdateState with Load. | |||
| EnforceOrderForLoad(update_state, attach->cast<CNodePtr>(), updated_refs); | |||
| } else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { | |||
| // Handle UpdateState with MakeTuple. | |||
| EnforceOrderForTuple(update_state, attach->cast<CNodePtr>(), updated_refs); | |||
| } | |||
| } | |||
| std::unordered_set<AnfNodePtr> FindUpdatedRefs(const CNodePtr &update_state) { | |||
| std::unordered_set<AnfNodePtr> updated_refs; | |||
| auto &users = manager_->node_users()[update_state]; | |||
| for (auto &user : users) { | |||
| auto cnode = dyn_cast<CNode>(user.first); | |||
| if (cnode == nullptr) { | |||
| continue; | |||
| } | |||
| if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) || | |||
| cnode->IsApply(prim::kPrimUpdateState)) { | |||
| continue; | |||
| } | |||
| for (auto &input : cnode->inputs()) { | |||
| if (IsRef(input)) { | |||
| updated_refs.insert(input); | |||
| } | |||
| } | |||
| } | |||
| return updated_refs; | |||
| } | |||
| bool IsRef(const AnfNodePtr &node) { | |||
| auto &abs = node->abstract(); | |||
| return abs != nullptr && abs->isa<abstract::AbstractRef>(); | |||
| } | |||
| void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load, | |||
| const std::unordered_set<AnfNodePtr> &refs) { | |||
| if (refs.find(load->input(1)) == refs.end()) { | |||
| // Skip if loaded parameter is not updated. | |||
| return; | |||
| } | |||
| // Find load users, ignore processed nodes. | |||
| auto load_users = FindLoadUsers(load, update_state); | |||
| // Find load users that not depend on the UpdateState, | |||
| // and than let UpdateState depend on them. | |||
| AddInputEdges(update_state, load_users); | |||
| } | |||
| void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, | |||
| const std::unordered_set<AnfNodePtr> &refs) { | |||
| // The UpdateState should be the only one user of the make_tuple. | |||
| // for performance, we only check the number of output edges. | |||
| if (manager_->node_users()[make_tuple].size() != 1) { | |||
| return; | |||
| } | |||
| // Find load users from the tuple of Load nodes. | |||
| std::unordered_set<AnfNodePtr> all_load_users; | |||
| auto &inputs = make_tuple->inputs(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto &input = inputs.at(i); | |||
| if (!IsPrimitiveCNode(input, prim::kPrimLoad)) { | |||
| // Skip non-Load nodes. | |||
| continue; | |||
| } | |||
| auto load = input->cast<CNodePtr>(); | |||
| if (refs.find(load->input(1)) == refs.end()) { | |||
| // Skip if loaded parameter is not updated. | |||
| continue; | |||
| } | |||
| auto load_users = FindLoadUsers(load, make_tuple); | |||
| all_load_users.insert(load_users.begin(), load_users.end()); | |||
| } | |||
| // Find load users that not depend on the UpdateState, | |||
| // and than let UpdateState depend on them. | |||
| AddInputEdges(update_state, all_load_users); | |||
| } | |||
| // Add load users as input edges of the update_state node. | |||
| void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) { | |||
| auto sorted_load_users = SortLoadUsers(load_users); | |||
| for (auto &load_user : sorted_load_users) { | |||
| if (!IsDependOn(load_user, update_state)) { | |||
| processed_nodes_.insert(load_user); | |||
| manager_->AddEdge(update_state, load_user); | |||
| } | |||
| } | |||
| } | |||
| // Sort load users by their topo sort order. | |||
| std::vector<AnfNodePtr> SortLoadUsers(const std::unordered_set<AnfNodePtr> &load_users) { | |||
| std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()}; | |||
| std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); }); | |||
| return vec; | |||
| } | |||
| // Check if the load user node depend on the given UpdateState node. | |||
| bool IsDependOn(const AnfNodePtr &load_user, const AnfNodePtr &update_state) { | |||
| size_t update_state_order = topo_sort_map_[update_state]; | |||
| if (topo_sort_map_[load_user] < update_state_order) { | |||
| return false; | |||
| } | |||
| auto user_cnode = dyn_cast<CNode>(load_user); | |||
| if (user_cnode == nullptr) { | |||
| return false; | |||
| } | |||
| size_t seen = NewSeenGeneration(); | |||
| std::queue<CNodePtr> q; | |||
| user_cnode->seen_ = seen; | |||
| q.push(user_cnode); | |||
| while (!q.empty()) { | |||
| auto cnode = q.front(); | |||
| q.pop(); | |||
| for (auto &input : cnode->inputs()) { | |||
| if (input == update_state) { | |||
| // Dependency found. | |||
| return true; | |||
| } | |||
| if (input->seen_ == seen) { | |||
| // Skip visited nodes. | |||
| continue; | |||
| } | |||
| if (topo_sort_map_[input] < update_state_order) { | |||
| // Skip input nodes that before the UpdateState node. | |||
| continue; | |||
| } | |||
| auto input_cnode = dyn_cast<CNode>(input); | |||
| if (input_cnode != nullptr) { | |||
| input_cnode->seen_ = seen; | |||
| q.push(input_cnode); | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool IsBefore(const AnfNodePtr &node1, const AnfNodePtr &node2) { | |||
| return topo_sort_map_[node1] < topo_sort_map_[node2]; | |||
| } | |||
| // Find Load users as the candidate nodes to enforce order of execution. | |||
| std::unordered_set<AnfNodePtr> FindLoadUsers(const CNodePtr &load, const AnfNodePtr &exclude) { | |||
| auto &node_users = manager_->node_users(); | |||
| auto iter = node_users.find(load); | |||
| if (iter == node_users.end()) { | |||
| return {}; | |||
| } | |||
| std::unordered_set<AnfNodePtr> load_users; | |||
| auto &users = iter->second; | |||
| for (auto &user : users) { | |||
| auto &user_node = user.first; | |||
| if (user_node == exclude) { | |||
| // Skip excluded node. | |||
| continue; | |||
| } | |||
| if (processed_nodes_.find(user_node) != processed_nodes_.end()) { | |||
| // Skip processed nodes. | |||
| continue; | |||
| } | |||
| auto cnode = dyn_cast<CNode>(user_node); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &inputs = cnode->inputs(); | |||
| const bool has_u_input = | |||
| std::any_of(inputs.begin() + 1, inputs.end(), [](const AnfNodePtr &input) { return HasAbstractUMonad(input); }); | |||
| if (has_u_input) { | |||
| // Skip nodes with memory side effects, which use u input. | |||
| continue; | |||
| } | |||
| load_users.insert(cnode); | |||
| } | |||
| return load_users; | |||
| } | |||
| private: | |||
| const FuncGraphPtr &func_graph_; | |||
| FuncGraphManagerPtr manager_; | |||
| std::unordered_map<AnfNodePtr, size_t> topo_sort_map_; | |||
| std::unordered_set<AnfNodePtr> processed_nodes_; | |||
| }; | |||
| } // namespace | |||
| // | |||
| // Enforce order of execution for Load users node. | |||
| // | |||
| void OrderEnforce(const FuncGraphPtr &func_graph) { | |||
| OrderEnforcer enforcer(func_graph); | |||
| enforcer.Run(); | |||
| } | |||
| } // namespace mindspore::pipeline | |||
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_
#include "ir/func_graph.h"
namespace mindspore::pipeline {
// Enforce order of execution of the given graph.
// Adds input edges so that UpdateState nodes depend on the users of the
// Load nodes they attach, pinning reads before the tracked updates.
void OrderEnforce(const FuncGraphPtr &func_graph);
}  // namespace mindspore::pipeline
#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_
| @@ -1456,7 +1456,10 @@ def test_while_forward(): | |||
| assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001) | |||
| @pytest.mark.skip(reason="not supported yet") | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_multi_add_assign(): | |||
| class Net(Cell): | |||
| def __init__(self, i1): | |||
| @@ -1493,7 +1496,10 @@ def test_multi_add_assign(): | |||
| np.testing.assert_array_equal(outputs, expects) | |||
| @pytest.mark.skip(reason="not supported yet") | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_multi_abs_add_assign(): | |||
| class Net(Cell): | |||
| def __init__(self, para): | |||