From 121a6a28d94ade51ad4df066a05a34896d9b23c8 Mon Sep 17 00:00:00 2001 From: He Wei Date: Tue, 27 Apr 2021 11:26:08 +0800 Subject: [PATCH] [auto-monad] Enforce order of exection for Loads user nodes in frontend --- mindspore/_extends/builtin_operations.py | 2 +- .../backend/optimizer/ascend/ascend_helper.cc | 16 +- .../backend/optimizer/ascend/ascend_helper.h | 2 +- .../bnupdate_eltwise_eltwise_fusion_pass.cc | 3 + .../bnupdate_eltwise_fusion_pass.cc | 3 + .../ascend/buffer_fusion/ub_pattern_fusion.cc | 8 +- ...eal_ref_and_split_unsupported_transdata.cc | 15 +- .../ascend/format_type/insert_cast.cc | 17 +- .../ascend/format_type/insert_trans_op.cc | 2 +- .../common/common_backend_optimization.cc | 20 ++ .../common/common_backend_optimization.h | 1 + .../ccsrc/backend/optimizer/common/helper.cc | 20 +- .../eliminate_redundant_output.cc | 13 +- .../optimizer/pass/optimize_dependence.cc | 90 +++--- .../optimizer/pass/optimize_updatestate.cc | 71 +++++ .../optimizer/pass/optimize_updatestate.h | 33 +++ .../backend/session/anf_runtime_algorithm.cc | 10 + .../backend/session/anf_runtime_algorithm.h | 1 + .../ccsrc/backend/session/ascend_session.cc | 1 + .../ccsrc/backend/session/cpu_session.cc | 1 + .../ccsrc/backend/session/gpu_session.cc | 2 + .../ccsrc/backend/session/session_basic.cc | 7 + .../ccsrc/backend/session/session_basic.h | 1 + mindspore/ccsrc/pipeline/jit/action.cc | 17 ++ .../jit/static_analysis/order_enforce.cc | 258 ++++++++++++++++++ .../jit/static_analysis/order_enforce.h | 27 ++ tests/st/auto_monad/test_auto_monad.py | 10 +- 27 files changed, 569 insertions(+), 82 deletions(-) create mode 100644 mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.cc create mode 100644 mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.h create mode 100644 mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc create mode 100644 mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.h diff --git a/mindspore/_extends/builtin_operations.py 
b/mindspore/_extends/builtin_operations.py index d894a4d488..7348e0322a 100644 --- a/mindspore/_extends/builtin_operations.py +++ b/mindspore/_extends/builtin_operations.py @@ -132,7 +132,7 @@ def Depend(value, expr): return value -def UpdateState(monad, expr): +def UpdateState(monad, *exprs): """Implement `UpdateState`.""" return monad diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 3143e2e0b1..b47f59d1f1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -90,7 +90,7 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & MS_EXCEPTION_IF_NULL(node_with_index.first); auto real_input = node_with_index.first; if (real_input->isa() || real_input->isa()) { - input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); + input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select); MS_EXCEPTION_IF_NULL(input_node); AnfAlgo::SetNodeInput(node, input_node, index); } @@ -120,10 +120,16 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An return node; } -AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { +AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, + const AnfNodePtr &node, const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_node); + for (auto &update_state : update_states) { + manager->SetEdge(update_state.first, update_state.second, node); + } std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; auto kernel_graph = func_graph->cast(); size_t out_num = 
AnfAlgo::GetOutputTensorNum(node); @@ -282,7 +288,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr & return cast; } -AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, +AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node, const KernelSelectPtr &kernel_select) { size_t outputs_num = AnfAlgo::GetOutputTensorNum(node); if (outputs_num == 0) { @@ -298,7 +304,7 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP return new_node; } // Multiple output - return InsertTransOpForMultipleOutput(func_graph, node, kernel_select); + return InsertTransOpForMultipleOutput(func_graph, orig_node, node, kernel_select); } AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h index 3253485c87..f03528ea96 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -103,7 +103,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr & AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select); -AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, +AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node, const KernelSelectPtr &kernel_select); CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc index 03bb2cfd94..438cc0dba4 100644 --- 
a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc @@ -66,6 +66,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); for (auto out_getitem : manager->node_users()[bnupdate]) { MS_EXCEPTION_IF_NULL(out_getitem.first); + if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) { + continue; + } auto out_getitem_ptr = out_getitem.first->cast(); MS_EXCEPTION_IF_NULL(out_getitem_ptr); auto input2 = out_getitem_ptr->input(2); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc index 17bffc7f11..6db793dfd2 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc @@ -43,6 +43,9 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); for (auto out_getitem : manager->node_users()[bnupdate]) { MS_EXCEPTION_IF_NULL(out_getitem.first); + if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) { + continue; + } auto out_getitem_ptr = out_getitem.first->cast(); MS_EXCEPTION_IF_NULL(out_getitem_ptr); auto input2 = out_getitem_ptr->input(2); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc index a2420eb92f..819637ace7 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -297,9 +297,11 @@ void 
GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, } else { int64_t prev_idx = 0; std::vector tuple_getitem_nodes; - std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), - std::back_inserter(tuple_getitem_nodes), - [](const std::pair &use_node) { return use_node.first; }); + for (auto &user : manager->node_users()[node]) { + if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimTupleGetItem)) { + tuple_getitem_nodes.emplace_back(user.first); + } + } std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); for (auto &getitem : tuple_getitem_nodes) { MS_EXCEPTION_IF_NULL(getitem); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.cc index d00c002f1f..6b9c64c91d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.cc @@ -163,7 +163,20 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get return func_graph->NewCNode(depend_nodes); } CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput( - const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr &op_info) const { + const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr &op_info) const { + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto cnode = orig_cnode; + auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode); + if (!update_states.empty()) { + auto kernel_graph = func_graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + cnode = kernel_graph->NewCNode(orig_cnode); + cnode->set_inputs(orig_cnode->inputs()); + for (auto &update_state : update_states) { + manager->SetEdge(update_state.first, 
update_state.second, cnode); + } + } MS_EXCEPTION_IF_NULL(op_info); auto ref_infos = op_info->ref_infos(); std::vector make_tuple_inputs; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc index 47d66edace..875b75ba0a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc @@ -30,9 +30,16 @@ namespace mindspore { namespace opt { namespace { -AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { +AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, + const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(cnode); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode); + for (auto &update_state : update_states) { + manager->SetEdge(update_state.first, update_state.second, cnode); + } std::vector make_tuple_inputs; AbstractBasePtrList abstract_list; make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); @@ -69,9 +76,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo MS_EXCEPTION_IF_NULL(make_tuple); make_tuple->set_abstract(std::make_shared(abstract_list)); return make_tuple; -} // namespace +} -AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { +AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { @@ -99,7 +106,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c return replace_node; } // Multiple output - return InsertCastForMultipleOutput(func_graph, cnode); + return 
InsertCastForMultipleOutput(func_graph, orig_cnode, cnode); } } // namespace @@ -124,7 +131,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo kernel_graph->ReplaceInternalOutput(node, new_node); } // process output - return InsertCastForOutput(func_graph, new_node); + return InsertCastForOutput(func_graph, cnode, new_node); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc index 0e6accb14e..99c4cb4d8b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc @@ -43,7 +43,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { kernel_graph->ReplaceInternalOutput(node, new_node); } - return InsertTransOpForOutput(func_graph, new_node, kernel_select_); + return InsertTransOpForOutput(func_graph, node, new_node, kernel_select_); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc index 90842069ee..0d838b6ac0 100644 --- a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc @@ -25,6 +25,7 @@ #include "backend/optimizer/pass/convert_const_scalar_to_tensor.h" #include "backend/optimizer/pass/convert_attr_to_unify_mindir.h" #include "backend/optimizer/pass/add_training_attr.h" +#include "backend/optimizer/pass/optimize_updatestate.h" #include "utils/ms_context.h" #include "debug/anf_ir_dump.h" @@ -58,5 +59,24 @@ void BackendCommonOptimization(const std::shared_ptr &kern DumpIR(file_name, kernel_graph); } } + +void 
CommonFinalOptimization(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + // Run optimizer passes. + auto optimizer = std::make_shared(); + auto pm = std::make_shared("final_opt"); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + // Dump IR if save_graphs is set. + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + const bool save_graphs = context->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + if (save_graphs) { + std::string filename = "hwopt_common_final_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(filename, kernel_graph); + } +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h index e673e101bb..f05bfefca8 100644 --- a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h +++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h @@ -20,6 +20,7 @@ namespace mindspore { namespace opt { void BackendCommonOptimization(const std::shared_ptr &kernel_graph); +void CommonFinalOptimization(const std::shared_ptr &kernel_graph); } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 2a4f64cbed..c215b10b7c 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -401,11 +401,9 @@ std::shared_ptr>> GetRealNodeUsedList(con } auto output_info_list = iter->second; for (const auto &output_info : output_info_list) { - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && - output_info.second == kDependAttachNodeIndex) { - continue; - } - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) { + auto 
cnode_name = AnfAlgo::GetCNodeName(output_info.first); + if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) || + (cnode_name == prim::kPrimUpdateState->name())) { continue; } output_node_list->push_back(output_info); @@ -426,12 +424,13 @@ std::shared_ptr>> GetRealNodeUsedListByOu } auto output_info_list = iter->second; for (const auto &output_info : output_info_list) { - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && - output_info.second == kDependAttachNodeIndex) { + auto cnode_name = AnfAlgo::GetCNodeName(output_info.first); + if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) || + (cnode_name == prim::kPrimUpdateState->name())) { continue; } size_t used_output_index; - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) { + if (cnode_name == prim::kPrimTupleGetItem->name()) { used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast(output_info.first)); } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) { used_output_index = output_index; @@ -906,12 +905,13 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C MS_EXCEPTION_IF_NULL(graph); auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - // find BatchNorm's output which is a Depend + // Find BatchNorm's output which is a Depend or UpdateState. 
for (const auto &node_index : manager->node_users()[old_node]) { AnfNodePtr output = node_index.first; size_t index = IntToSize(node_index.second); MS_EXCEPTION_IF_NULL(output); - if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) || + AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) { auto depend = output->cast(); MS_EXCEPTION_IF_NULL(depend); depend->set_input(index, new_node); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc index d0767d7861..a183371685 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc @@ -66,13 +66,14 @@ bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr auto output_num = output->cast()->size() - 1; getitem_list->clear(); getitem_list->resize(output_num, nullptr); - const auto &users = mng->node_users()[node]; + auto users = mng->node_users()[node]; bool changed = false; - AnfNodePtrList user_nodes; - std::transform(users.begin(), users.end(), std::back_inserter(user_nodes), - [](const std::pair &user) { return user.first; }); - for (const auto &getitem : user_nodes) { - MS_EXCEPTION_IF_NULL(getitem); + for (const auto &user : users) { + if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { + // Sometime, the user of MakeTuple is not a TupleGetItem, but a UpdateState. + continue; + } + auto &getitem = user.first; auto idx = GetIndex(getitem); if (idx >= output_num) { MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. 
getitem node: " << getitem->DebugString(); diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc index daa4df3b7f..b6ad9e7e81 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc @@ -35,19 +35,17 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno const std::vector &new_depend_inputs) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(cnode); - auto kernel_graph = func_graph->cast>(); - CNodePtr new_depend = nullptr; + auto kernel_graph = func_graph->cast(); if (kernel_graph == nullptr) { - new_depend = func_graph->NewCNode(new_depend_inputs); + auto new_depend = func_graph->NewCNode(new_depend_inputs); MS_EXCEPTION_IF_NULL(new_depend); new_depend->set_abstract(cnode->abstract()); new_depend->set_scope(cnode->scope()); - } else { - new_depend = kernel_graph->NewCNode(cnode); - MS_EXCEPTION_IF_NULL(new_depend); - new_depend->set_inputs(new_depend_inputs); + return new_depend; } - func_graph->manager()->Replace(cnode, new_depend); + auto new_depend = kernel_graph->NewCNode(cnode); + MS_EXCEPTION_IF_NULL(new_depend); + new_depend->set_inputs(new_depend_inputs); return new_depend; } @@ -77,9 +75,9 @@ AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, con auto replace_node = eliminate_node->input(kSingleInputIndex); std::vector new_depend_inputs = cnode->inputs(); new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node; - auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs); - auto new_node = new_cnode->cast(); - return new_node; + auto new_depend = CreateNewDependNode(func_graph, cnode, new_depend_inputs); + func_graph->manager()->Replace(cnode, new_depend); + return new_depend; } AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { @@ -157,55 +155,53 @@ const BaseRef 
OptimizeDependence::DefinePattern() const { return VectorRef({X, Xs}); } -std::pair SearchTransDataAndCast(const AnfNodePtr &node, bool is_first_node) { - if (node == nullptr || !node->isa()) { - return std::pair(nullptr, 0); - } - // get real input of depend and update state. - size_t replace_input_index = 0; - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { - replace_input_index = is_first_node ? kDependAttachNodeIndex : kRealInputIndexInDepend; - } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) { - replace_input_index = is_first_node ? kUpdateStateStateInput : kUpdateStateRealInput; - } else { - return std::pair(nullptr, 0); - } - // check whether real input is cast or trans data - auto real_input = node->cast()->input(replace_input_index); - if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimCast) || - AnfAlgo::CheckPrimitiveType(real_input, prim::KPrimTransData) || - AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimMakeTuple)) { - return std::pair(node, replace_input_index); - } - return SearchTransDataAndCast(real_input, false); +std::vector SearchTransDataAndCast(const CNodePtr &cnode) { + // Search Depend and UpdateState only. + if (!cnode->IsApply(prim::kPrimDepend) && !cnode->IsApply(prim::kPrimUpdateState)) { + return {}; + } + // Find inputs which is Cast or TransData. 
+ std::vector result; + for (size_t i = 1; i < cnode->size(); ++i) { + auto &input = cnode->input(i); + if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) || + AnfAlgo::CheckPrimitiveType(input, prim::KPrimTransData) || + AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) { + result.emplace_back(i); + } + } + return result; } const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { + auto cnode = dyn_cast(node); + if (cnode == nullptr) { return nullptr; } - // Get the cnode with repalce input index - auto cnode_with_input_index = SearchTransDataAndCast(node, true); - if (cnode_with_input_index.first == nullptr) { + // Search inputs to be replaced. + auto candidate_inputs = SearchTransDataAndCast(cnode); + if (candidate_inputs.empty()) { return nullptr; } - size_t replace_index = cnode_with_input_index.second; - auto depend_cnode = cnode_with_input_index.first->cast(); - MS_EXCEPTION_IF_NULL(depend_cnode); - // Get new node which will act as new input of depend or UpdateState. - std::vector new_depend_inputs = depend_cnode->inputs(); - auto replace_node = GetConvertNode(func_graph, depend_cnode, replace_index); - if (replace_node == nullptr) { - return nullptr; + // Get new nodes which will act as new inputs of Depend or UpdateState. + std::vector new_inputs = cnode->inputs(); + bool inputs_changed = false; + for (auto index : candidate_inputs) { + auto replace_node = GetConvertNode(func_graph, cnode, index); + if (replace_node != nullptr) { + new_inputs[index] = replace_node; + inputs_changed = true; + } } - new_depend_inputs[replace_index] = replace_node; - auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs); - if (new_depend == nullptr) { + if (!inputs_changed) { return nullptr; } + // Create a new Depend node to replace the old one if inputs changed. 
+ auto new_depend = CreateNewDependNode(func_graph, cnode, new_inputs); + func_graph->manager()->Replace(cnode, new_depend); return nullptr; } diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.cc b/mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.cc new file mode 100644 index 0000000000..30bff16acb --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/pass/optimize_updatestate.h" +#include +#include +#include +#include "base/core_ops.h" +#include "utils/utils.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +constexpr size_t kInputIndex = 1; +constexpr size_t kAttachIndex = 2; +constexpr size_t kAdditionalAttachIndex = 3; + +const BaseRef OptimizeUpdateState::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimUpdateState, Xs}); +} + +const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + auto update_state = dyn_cast(node); + MS_EXCEPTION_IF_NULL(update_state); + if (update_state->size() <= kAdditionalAttachIndex) { + // Skip UpdateState nodes with no additional attaches. 
+ return nullptr; + } + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + std::vector new_inputs; + new_inputs.emplace_back(update_state->input(0)); + new_inputs.emplace_back(update_state->input(kInputIndex)); + new_inputs.emplace_back(update_state->input(kAttachIndex)); + for (size_t i = kAdditionalAttachIndex; i < update_state->size(); ++i) { + auto &attach = update_state->input(i); + auto &users = node_users[attach]; + if ((users.size() == 1) && (users.front().first == update_state)) { + // If the only user of attach is the UpdateState node, drop the attach node. + continue; + } + new_inputs.emplace_back(attach); + } + if (new_inputs.size() == update_state->size()) { + // Attaches not changed. + return nullptr; + } + // Attaches changed, make a new UpdateState. + auto new_update_state = func_graph->NewCNode(new_inputs); + new_update_state->set_abstract(update_state->abstract()); + new_update_state->set_scope(update_state->scope()); + return new_update_state; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.h b/mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.h new file mode 100644 index 0000000000..937ba45703 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class OptimizeUpdateState : public PatternProcessPass { + public: + explicit OptimizeUpdateState(bool multigraph = true) : PatternProcessPass("optimize_updatestate", multigraph) {} + ~OptimizeUpdateState() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_ diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index e14d5ba16a..164e0c51e7 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1931,5 +1931,15 @@ void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(NotNull root_ {NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())), root_graph->output()}); root_graph->set_output(make_tuple); } + +AnfNodeIndexSet AnfRuntimeAlgorithm::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { + AnfNodeIndexSet update_states; + for (auto &user : manager->node_users()[node]) { + if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) { + update_states.insert(user); + } + } + return update_states; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index ad36329985..b95d671d4c 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -267,6 +267,7 @@ class AnfRuntimeAlgorithm { static void 
GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, std::set *visited); static void InsertMakeTupleForOutput(NotNull root_graph); + static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 79657959bd..6b62316c8e 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -936,6 +936,7 @@ void AscendSession::InitRuntimeResource() { void AscendSession::HardwareOptimize(const std::shared_ptr &kernel_graph) const { MS_LOG(INFO) << "HardwareOptimize start!"; opt::AscendBackendOptimization(kernel_graph); + FinalOptimize(kernel_graph); GraphKernelOptimize(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc index 7d08c7607e..bd7f946697 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.cc +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -104,6 +104,7 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr SetKernelInfo(graph.get()); MS_LOG(INFO) << "Set kernel info end"; Optimize(graph); + FinalOptimize(graph); MS_LOG(INFO) << "Build kernel"; BuildKernel(graph.get()); // Remove reorder after PS feature finish adapting push/pull in auto_monad. 
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index b537bf63da..9d79ea19dc 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -341,6 +341,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) { SelectKernel(graph); // Graph optimization relevant to device data format HardwareOptimize(graph); + // Run final optimization + FinalOptimize(graph); // Graph kernel fusion optimization GraphKernelOptimize(graph); // Start gpu kernel runtime diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 3ce7e07c55..0ec4ae7d54 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -2343,6 +2344,12 @@ void SessionBasic::ClearAllBucket(const GraphId &graph_id) { } } +void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const { + MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id(); + opt::CommonFinalOptimization(graph); + MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id(); +} + void SessionBasic::DumpGraph(const std::shared_ptr &kernel_graph) { #ifdef ENABLE_DUMP_IR auto context_ptr = MsContext::GetInstance(); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index a30d28f3bb..4ddbd5df21 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -172,6 +172,7 @@ class SessionBasic : public std::enable_shared_from_this { virtual void UpdateOutputTensors(const VectorRef *outputs, const std::map &tensor_to_node); virtual void UnifyMindIR(const KernelGraphPtr &graph) {} + virtual void FinalOptimize(const KernelGraphPtr &graph) const; virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const 
AnfNodePtrList &outputs) { return 0; } virtual GraphId CompileGraphImpl(NotNull func_graph) { return kInvalidGraphId; } virtual void BuildGraphImpl(GraphId) {} diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index ef07516332..3e4b603572 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -32,6 +32,7 @@ #include "pipeline/jit/parse/parse_base.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/static_analysis/auto_monad.h" +#include "pipeline/jit/static_analysis/order_enforce.h" #include "abstract/abstract_value.h" #include "pipeline/jit/static_analysis/static_analysis.h" #include "pipeline/jit/static_analysis/program_specialize.h" @@ -343,6 +344,18 @@ bool AutoMonadAction(const ResourcePtr &res) { return true; } +bool OrderEnforceAction(const ResourcePtr &res) { + if (res->manager() == nullptr) { + MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null"; + } + auto func_graph = res->func_graph(); + if (func_graph == nullptr) { + MS_LOG(EXCEPTION) << "Order-Enforce error, graph is null"; + } + pipeline::OrderEnforce(func_graph); + return true; +} + bool InferenceOptPrepareAction(const ResourcePtr &res) { if (res->manager() == nullptr) { MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; @@ -752,6 +765,7 @@ std::vector GePipeline() { // Add opt-stage python pass stub actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub)); actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); + actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); actions.emplace_back(std::make_pair("validate", ValidateAction)); return actions; } @@ -765,6 +779,8 @@ std::vector VmPipeline() { // Add opt-stage python pass stub actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub)); + actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); + 
actions.emplace_back(std::make_pair("validate", ValidateAction)); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) if (ps::PSContext::instance()->is_worker()) { @@ -784,6 +800,7 @@ std::vector VmPipeline() { std::vector PServerPipeline() { auto actions = CommonPipeline(); actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); + actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); actions.emplace_back(std::make_pair("validate", ValidateAction)); actions.emplace_back(std::make_pair("pserver", StartPSServerAction)); return actions; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc new file mode 100644 index 0000000000..1d2387fd54 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc @@ -0,0 +1,258 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pipeline/jit/static_analysis/order_enforce.h" +#include <algorithm> +#include <memory> +#include <queue> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <vector> +#include "base/core_ops.h" + +namespace mindspore::pipeline { +namespace { + +class OrderEnforcer { + public: + explicit OrderEnforcer(const FuncGraphPtr &func_graph) : func_graph_(func_graph), manager_(func_graph->manager()) { + MS_EXCEPTION_IF_NULL(func_graph_); + MS_EXCEPTION_IF_NULL(manager_); + } + ~OrderEnforcer() = default; + + void Run() { + auto nodes = MakeTopoSortMap(); + for (auto &node : nodes) { + HandleNode(node); + } + } + + private: + AnfNodePtrList MakeTopoSortMap() { + auto nodes = TopoSort(func_graph_->get_return()); + for (size_t i = 0; i < nodes.size(); ++i) { + topo_sort_map_.emplace(nodes[i], i); + } + return nodes; + } + + void HandleNode(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) { + // Skip nodes other than UpdateState. + return; + } + auto update_state = node->cast<CNodePtr>(); + if (!HasAbstractUMonad(update_state->input(1))) { + // Skip UpdateStates for IO. + return; + } + auto updated_refs = FindUpdatedRefs(update_state); + if (updated_refs.empty()) { + // Skip UpdateStates that do not have updated refs. + return; + } + auto &attach = update_state->input(2); + if (IsPrimitiveCNode(attach, prim::kPrimLoad)) { + // Handle UpdateState with Load. + EnforceOrderForLoad(update_state, attach->cast<CNodePtr>(), updated_refs); + } else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { + // Handle UpdateState with MakeTuple.
+ EnforceOrderForTuple(update_state, attach->cast<CNodePtr>(), updated_refs); + } + } + + std::unordered_set<AnfNodePtr> FindUpdatedRefs(const CNodePtr &update_state) { + std::unordered_set<AnfNodePtr> updated_refs; + auto &users = manager_->node_users()[update_state]; + for (auto &user : users) { + auto cnode = dyn_cast<CNode>(user.first); + if (cnode == nullptr) { + continue; + } + if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) || + cnode->IsApply(prim::kPrimUpdateState)) { + continue; + } + for (auto &input : cnode->inputs()) { + if (IsRef(input)) { + updated_refs.insert(input); + } + } + } + return updated_refs; + } + + bool IsRef(const AnfNodePtr &node) { + auto &abs = node->abstract(); + return abs != nullptr && abs->isa<abstract::AbstractRef>(); + } + + void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load, + const std::unordered_set<AnfNodePtr> &refs) { + if (refs.find(load->input(1)) == refs.end()) { + // Skip if loaded parameter is not updated. + return; + } + // Find load users, ignore processed nodes. + auto load_users = FindLoadUsers(load, update_state); + // Find load users that do not depend on the UpdateState, + // and then let UpdateState depend on them. + AddInputEdges(update_state, load_users); + } + + void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, + const std::unordered_set<AnfNodePtr> &refs) { + // The UpdateState should be the only one user of the make_tuple. + // for performance, we only check the number of output edges. + if (manager_->node_users()[make_tuple].size() != 1) { + return; + } + // Find load users from the tuple of Load nodes. + std::unordered_set<AnfNodePtr> all_load_users; + auto &inputs = make_tuple->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + auto &input = inputs.at(i); + if (!IsPrimitiveCNode(input, prim::kPrimLoad)) { + // Skip non-Load nodes. + continue; + } + auto load = input->cast<CNodePtr>(); + if (refs.find(load->input(1)) == refs.end()) { + // Skip if loaded parameter is not updated.
+ continue; + } + auto load_users = FindLoadUsers(load, make_tuple); + all_load_users.insert(load_users.begin(), load_users.end()); + } + // Find load users that do not depend on the UpdateState, + // and then let UpdateState depend on them. + AddInputEdges(update_state, all_load_users); + } + + // Add load users as input edges of the update_state node. + void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) { + auto sorted_load_users = SortLoadUsers(load_users); + for (auto &load_user : sorted_load_users) { + if (!IsDependOn(load_user, update_state)) { + processed_nodes_.insert(load_user); + manager_->AddEdge(update_state, load_user); + } + } + } + + // Sort load users by their topo sort order. + std::vector<AnfNodePtr> SortLoadUsers(const std::unordered_set<AnfNodePtr> &load_users) { + std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()}; + std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); }); + return vec; + } + + // Check if the load user node depends on the given UpdateState node. + bool IsDependOn(const AnfNodePtr &load_user, const AnfNodePtr &update_state) { + size_t update_state_order = topo_sort_map_[update_state]; + if (topo_sort_map_[load_user] < update_state_order) { + return false; + } + auto user_cnode = dyn_cast<CNode>(load_user); + if (user_cnode == nullptr) { + return false; + } + size_t seen = NewSeenGeneration(); + std::queue<CNodePtr> q; + user_cnode->seen_ = seen; + q.push(user_cnode); + while (!q.empty()) { + auto cnode = q.front(); + q.pop(); + for (auto &input : cnode->inputs()) { + if (input == update_state) { + // Dependency found. + return true; + } + if (input->seen_ == seen) { + // Skip visited nodes. + continue; + } + if (topo_sort_map_[input] < update_state_order) { + // Skip input nodes that come before the UpdateState node.
+ continue; + } + auto input_cnode = dyn_cast<CNode>(input); + if (input_cnode != nullptr) { + input_cnode->seen_ = seen; + q.push(input_cnode); + } + } + } + return false; + } + + bool IsBefore(const AnfNodePtr &node1, const AnfNodePtr &node2) { + return topo_sort_map_[node1] < topo_sort_map_[node2]; + } + + // Find Load users as the candidate nodes to enforce order of execution. + std::unordered_set<AnfNodePtr> FindLoadUsers(const CNodePtr &load, const AnfNodePtr &exclude) { + auto &node_users = manager_->node_users(); + auto iter = node_users.find(load); + if (iter == node_users.end()) { + return {}; + } + std::unordered_set<AnfNodePtr> load_users; + auto &users = iter->second; + for (auto &user : users) { + auto &user_node = user.first; + if (user_node == exclude) { + // Skip excluded node. + continue; + } + if (processed_nodes_.find(user_node) != processed_nodes_.end()) { + // Skip processed nodes. + continue; + } + auto cnode = dyn_cast<CNode>(user_node); + MS_EXCEPTION_IF_NULL(cnode); + auto &inputs = cnode->inputs(); + const bool has_u_input = + std::any_of(inputs.begin() + 1, inputs.end(), [](const AnfNodePtr &input) { return HasAbstractUMonad(input); }); + if (has_u_input) { + // Skip nodes with memory side effects, which use u input. + continue; + } + load_users.insert(cnode); + } + return load_users; + } + + private: + const FuncGraphPtr &func_graph_; + FuncGraphManagerPtr manager_; + std::unordered_map<AnfNodePtr, size_t> topo_sort_map_; + std::unordered_set<AnfNodePtr> processed_nodes_; +}; + +} // namespace + +// +// Enforce order of execution for Load users node.
+// +void OrderEnforce(const FuncGraphPtr &func_graph) { + OrderEnforcer enforcer(func_graph); + enforcer.Run(); +} +} // namespace mindspore::pipeline diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.h b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.h new file mode 100644 index 0000000000..1327b44c68 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_ + +#include "ir/func_graph.h" + +namespace mindspore::pipeline { +// Enforce order of execution of the given graph. 
+void OrderEnforce(const FuncGraphPtr &func_graph); +} // namespace mindspore::pipeline + +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_ diff --git a/tests/st/auto_monad/test_auto_monad.py b/tests/st/auto_monad/test_auto_monad.py index 82ebece1bd..d79b5ba7ef 100644 --- a/tests/st/auto_monad/test_auto_monad.py +++ b/tests/st/auto_monad/test_auto_monad.py @@ -1456,7 +1456,10 @@ def test_while_forward(): assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001) -@pytest.mark.skip(reason="not supported yet") +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard def test_multi_add_assign(): class Net(Cell): def __init__(self, i1): @@ -1493,7 +1496,10 @@ def test_multi_add_assign(): np.testing.assert_array_equal(outputs, expects) -@pytest.mark.skip(reason="not supported yet") +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard def test_multi_abs_add_assign(): class Net(Cell): def __init__(self, para):