From d417dddb242b90ff7fafb5b1e272e171c8c1f5b7 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Tue, 29 Dec 2020 10:19:10 +0800 Subject: [PATCH] enable loop sink when no getnext in execution orders --- .../device/ascend/ascend_stream_assign.cc | 53 ++++- .../device/ascend/ascend_stream_assign.h | 2 + .../ccsrc/runtime/device/kernel_adjust.cc | 198 +++++++++++------- .../ccsrc/runtime/device/kernel_adjust.h | 3 +- mindspore/ccsrc/utils/utils.h | 1 + 5 files changed, 173 insertions(+), 84 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index 28bf6f3e2c..83c27ae779 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -762,6 +762,39 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { return false; } +bool AscendStreamAssign::IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode) { + auto cnode_out_num = AnfAlgo::GetOutputTensorNum(cnode); + auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); + std::set output_index_set; + // Assign Communicate Op Memory firstly. + for (const auto &node : nodes) { + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); + MS_EXCEPTION_IF_NULL(item_with_index.first); + if (!item_with_index.first->isa() || !AnfAlgo::IsRealKernel(item_with_index.first)) { + continue; + } + if (item_with_index.first == cnode) { + output_index_set.insert(item_with_index.second); + } + } + + MS_LOG(INFO) << "Node " << cnode->fullname_with_scope() << " has " << cnode_out_num + << " outputs, in graph output num:" << output_index_set.size(); + return cnode_out_num == output_index_set.size(); +} + +vector::iterator AscendStreamAssign::FindGraphEnd(vector::iterator begin, + vector::iterator end) { + while (begin != end) { + if (AnfAlgo::HasNodeAttr(kAttrFpBpEnd, *begin)) { + MS_LOG(INFO) << "FpBp end op is " << (*begin)->fullname_with_scope(); + return begin; + } + ++begin; + } + return end; +} + // section5 void AscendStreamAssign::InsertEventForHcomParallel(const NotNull &graph_ptr) { MS_LOG(INFO) << "Start"; @@ -780,15 +813,23 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNullfullname_with_scope() - << ", can't find target for insert recv op, no insert send/recv"; - it = cnodes.erase(it); - continue; + if (IsAllOutGraphOut(graph_ptr, cur_hcom_node)) { + // if hcom's all output is graph output, we need to insert send/recv to fpbp end in data sink mode + target = FindGraphEnd(it, cnodes.end()); + } + + if (target == cnodes.end()) { + MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope() + << ", can't find target for insert recv op, no insert send/recv"; + it = cnodes.erase(it); + continue; + } } // deal recv op @@ -824,7 +865,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr); if (inputs_cnode.empty()) { cnodes.emplace_back(cur_cnode_ptr); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h index b0bb6e6a54..8730934c33 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -212,6 +212,8 @@ class AscendStreamAssign { std::map event_map_{}; std::set middle_active_streams_{}; // new policy end + bool IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode); + vector::iterator FindGraphEnd(vector::iterator begin, vector::iterator end); }; } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.cc b/mindspore/ccsrc/runtime/device/kernel_adjust.cc index 7b8360ba97..25d6de491f 100644 --- a/mindspore/ccsrc/runtime/device/kernel_adjust.cc +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "backend/session/anf_runtime_algorithm.h" #include "utils/ms_context.h" @@ -104,7 +105,18 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &kernel_graph_ptr) { +bool KernelAdjust::ExistGetNext(const std::shared_ptr &kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + const std::vector &cnode_list = kernel_graph_ptr->execution_order(); + for (const auto &cnode : cnode_list) { + if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { + return true; + } + } + return false; +} + +bool KernelAdjust::ExistIndependent(const std::shared_ptr &kernel_graph_ptr) { MS_EXCEPTION_IF_NULL(kernel_graph_ptr); const auto &exe_orders = kernel_graph_ptr->execution_order(); for (const auto &node : exe_orders) { @@ -128,8 +140,13 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr MS_LOG(INFO) << "KernelGraph:" << kernel_graph_ptr->graph_id() << " is dynamic shape, skip InsertSwitchLoop"; return; } - bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX; - ReorderGetNext(kernel_graph_ptr); + bool exist_getnext = ExistGetNext(kernel_graph_ptr); + bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX && exist_getnext; + MS_LOG(INFO) << "GetNext exist:" << exist_getnext << " End of Sequence mode:" << eos_mode + << " iter num:" << ConfigManager::GetInstance().iter_num(); + if (exist_getnext) { + ReorderGetNext(kernel_graph_ptr); + } std::map switch_loop_input; CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); @@ -159,84 +176,96 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr std::vector getnext_active_streams; std::vector fpbp_active_streams; CNodePtr getnext_cnode; + uint32_t getnext_switch_stream_id = UINT32_MAX; + uint32_t fpbp_start_event_id = UINT32_MAX; + uint32_t eos_start_event_id = UINT32_MAX; uint32_t eos_done_event_id = UINT32_MAX; + size_t i = 0; // getnext loop process - // getnext loop stream switch op - CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kGetNextStreamSwitch); - MS_EXCEPTION_IF_NULL(getnext_switch_app); - uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); - exec_order.push_back(getnext_switch_app); - - // getnext op - uint32_t getnext_stream_id = resource_manager.ApplyNewStream(); - size_t i = 0; - for (; i < orders.size(); i++) { - auto node = orders[i]; - exec_order.push_back(node); - AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); - if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { - getnext_cnode = node; - break; + if (exist_getnext) { + // getnext loop stream switch op + getnext_switch_stream_id = resource_manager.ApplyNewStream(); + uint32_t getnext_stream_id = resource_manager.ApplyNewStream(); + CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kGetNextStreamSwitch); + MS_EXCEPTION_IF_NULL(getnext_switch_app); + AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); + // update getnext loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), getnext_switch_app); + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(getnext_stream_id), getnext_switch_app); + AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue(kGetNextStreamSwitch), getnext_switch_app); + exec_order.push_back(getnext_switch_app); + MS_LOG(INFO) << "GetNext loop insert Stream Switch " << getnext_switch_app->fullname_with_scope(); + + // getnext op + for (; i < orders.size(); i++) { + auto node = orders[i]; + exec_order.push_back(node); + AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); + if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { + getnext_cnode = node; + break; + } } - } - - // update getnext loop stream switch true_branch_stream attr - AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), getnext_switch_app); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(getnext_stream_id), getnext_switch_app); - AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue(kGetNextStreamSwitch), getnext_switch_app); - // getnext loop fpbp start send - uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent(); - CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id); - AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get()); - exec_order.push_back(fpbp_start_send); + // getnext loop fpbp start send + fpbp_start_event_id = resource_manager.ApplyNewEvent(); + CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id); + AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get()); + exec_order.push_back(fpbp_start_send); + MS_LOG(INFO) << "GetNext loop insert FpBp start Send " << fpbp_start_send->fullname_with_scope(); + + if (eos_mode) { + // getnext loop eos start send + eos_start_event_id = resource_manager.ApplyNewEvent(); + CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id); + AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get()); + exec_order.push_back(eos_start_send); + MS_LOG(INFO) << "GetNext loop insert EoS start Send " << eos_start_send->fullname_with_scope(); + } + } + // End Of Sequence loop process if (eos_mode) { - // getnext loop eos start send - uint32_t eos_start_event_id = resource_manager.ApplyNewEvent(); - CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id); - AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get()); - exec_order.push_back(eos_start_send); - - // End Of Sequence loop process // eos loop stream switch + uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); + uint32_t eos_stream_id = resource_manager.ApplyNewStream(); CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kEosStreamSwitch); MS_EXCEPTION_IF_NULL(eos_switch_app); - uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get()); AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), eos_switch_app); + // update eos loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(eos_stream_id), eos_switch_app); + AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue(kEosStreamSwitch), eos_switch_app); exec_order.push_back(eos_switch_app); + MS_LOG(INFO) << "EoS loop insert Stream Switch " << eos_switch_app->fullname_with_scope(); // eos loop eos start recv CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id); - uint32_t eos_stream_id = resource_manager.ApplyNewStream(); AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get()); exec_order.push_back(eos_start_recv); - - // update eos loop stream switch true_branch_stream attr - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(eos_stream_id), eos_switch_app); - AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue(kEosStreamSwitch), eos_switch_app); + MS_LOG(INFO) << "EoS loop insert EoS Recv " << eos_start_recv->fullname_with_scope(); // EndOfSequence op CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode); MS_EXCEPTION_IF_NULL(end_of_sequence_op); AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get()); exec_order.push_back(end_of_sequence_op); + MS_LOG(INFO) << "EoS loop insert Eos Op " << end_of_sequence_op->fullname_with_scope(); // eos loop eos done send eos_done_event_id = resource_manager.ApplyNewEvent(); CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, eos_done_event_id); AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get()); exec_order.push_back(eos_done_send); + MS_LOG(INFO) << "EoS loop insert EoS done Send " << eos_done_send->fullname_with_scope(); // eos loop stream active fpbp_active_streams.push_back(eos_switch_stream_id); } - bool exit_independent = ExitIndependent(kernel_graph_ptr); - if (exit_independent) { + bool exist_independent = ExistIndependent(kernel_graph_ptr); + if (exist_independent) { // Independet parallel CNodePtr independent_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kIndependentStreamSwitch); @@ -246,68 +275,80 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), independent_switch_app); AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue(kIndependentStreamSwitch), independent_switch_app); exec_order.push_back(independent_switch_app); + MS_LOG(INFO) << "Independent op loop insert Stream Switch " << independent_switch_app->fullname_with_scope(); } // fpbp loop process // fpbp loop stream switch + uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); + uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kFpBpStreamSwitch); MS_EXCEPTION_IF_NULL(fpbp_switch_app); - uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), fpbp_switch_app); - - exec_order.push_back(fpbp_switch_app); - - // fpbp loop fpbp start recv - CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id); - uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get()); - exec_order.push_back(fpbp_start_recv); - // update fpbp loop stream switch true_branch_stream attr AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(fpbp_stream_id), fpbp_switch_app); AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue(kFpBpStreamSwitch), fpbp_switch_app); + exec_order.push_back(fpbp_switch_app); + MS_LOG(INFO) << "FpBp loop insert Stream Switch " << fpbp_switch_app->fullname_with_scope(); + + if (exist_getnext) { + // fpbp loop fpbp start recv + CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id); + AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get()); + exec_order.push_back(fpbp_start_recv); + MS_LOG(INFO) << "FpBp loop insert FpBp start Recv " << fpbp_start_recv->fullname_with_scope(); + } + // next loop AssignAdd CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, false); MS_EXCEPTION_IF_NULL(assign_add_one); AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get()); exec_order.push_back(assign_add_one); + MS_LOG(INFO) << "FpBp loop insert next loop AssignAdd " << assign_add_one->fullname_with_scope(); - // fpbp memcpy + // fpbp getnext output memcpy std::vector memcpy_list; std::vector other_list; - CNodePtr cur_cnode = nullptr; - for (size_t idx = i + 1; idx < orders.size(); idx++) { - cur_cnode = orders[idx]; - if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { - auto pre_node = orders[idx - 1]; - auto pre_kernel_name = AnfAlgo::GetCNodeName(pre_node); - if (pre_kernel_name == kAtomicAddrCleanOpName) { - other_list.pop_back(); - memcpy_list.push_back(pre_node); + if (exist_getnext) { + CNodePtr cur_cnode = nullptr; + for (size_t idx = i + 1; idx < orders.size(); idx++) { + cur_cnode = orders[idx]; + if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { + auto pre_node = orders[idx - 1]; + auto pre_kernel_name = AnfAlgo::GetCNodeName(pre_node); + if (pre_kernel_name == kAtomicAddrCleanOpName) { + other_list.pop_back(); + memcpy_list.push_back(pre_node); + } + memcpy_list.emplace_back(cur_cnode); + } else { + other_list.emplace_back(cur_cnode); } - memcpy_list.emplace_back(cur_cnode); - } else { - other_list.emplace_back(cur_cnode); } + (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); + } else { + other_list = orders; } - (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); - // fpbp loop eos done recv if (eos_mode) { CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id); AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get()); exec_order.push_back(eos_done_recv); + MS_LOG(INFO) << "FpBp loop insert EoS done Recv " << eos_done_recv->fullname_with_scope(); } // stream active to activate getnext loop - CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(getnext_active_app); - getnext_active_streams.push_back(getnext_switch_stream_id); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(getnext_active_streams), - getnext_active_app); - exec_order.push_back(getnext_active_app); + if (exist_getnext) { + CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(getnext_active_app); + getnext_active_streams.push_back(getnext_switch_stream_id); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(getnext_active_streams), + getnext_active_app); + exec_order.push_back(getnext_active_app); + MS_LOG(INFO) << "FpBp loop insert GetNext loop Stream Active " << getnext_active_app->fullname_with_scope(); + } // fpbp loop other ops (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order)); @@ -315,7 +356,9 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr // current assign add op CNodePtr cur_assign_add = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, true); MS_EXCEPTION_IF_NULL(cur_assign_add); + AnfAlgo::SetNodeAttr(kAttrFpBpEnd, MakeValue(true), cur_assign_add); exec_order.push_back(cur_assign_add); + MS_LOG(INFO) << "FpBp loop insert current loop AssignAdd " << cur_assign_add->fullname_with_scope(); // stream active to activate fpbp loop and eos loop CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); @@ -323,6 +366,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr fpbp_active_streams.push_back(fpbp_switch_stream_id); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(fpbp_active_streams), fpbp_active_app); exec_order.push_back(fpbp_active_app); + MS_LOG(INFO) << "FpBp loop insert FpBp loop and Eos loop Stream Active " << fpbp_active_app->fullname_with_scope(); kernel_graph_ptr->set_execution_order(exec_order); } diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.h b/mindspore/ccsrc/runtime/device/kernel_adjust.h index bb0cbc6146..c4d34696ed 100644 --- a/mindspore/ccsrc/runtime/device/kernel_adjust.h +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.h @@ -86,7 +86,8 @@ class KernelAdjust { void LoadSwitchInputs(std::vector *inputs); void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info, NotNull kernel_graph_ptr); - bool ExitIndependent(const std::shared_ptr &graph_ptr); + bool ExistIndependent(const std::shared_ptr &kernel_graph_ptr); + bool ExistGetNext(const std::shared_ptr &kernel_graph_ptr); }; } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 78960e2b42..b6ffdbd9f2 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -315,6 +315,7 @@ constexpr auto kAttrOutputUsedNum = "output_used_num"; constexpr auto kAttrHasBias = "has_bias"; constexpr auto kAttrN = "n"; constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active"; +constexpr auto kAttrFpBpEnd = "fpbp_end"; constexpr auto kAttrFusion = "fusion"; constexpr auto kAttrGroup = "group"; constexpr auto kAttrOp = "op";