diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 557f083d02..ebcbd9dd35 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -1094,10 +1094,15 @@ void KernelGraph::PrintGraphExecuteOrder() const { } } + std::string group_str; + if (AnfAlgo::GetKernelType(cur_cnode_ptr) == HCCL_KERNEL && AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) { + group_str = ", group[" + AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrGroup) + "]"; + } + MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id[" << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]" - << event_str << label_str << active_stream_str; + << event_str << label_str << active_stream_str << group_str; } } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index fd1a18ff97..7cfe92d0c7 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -22,6 +22,8 @@ #include "ir/manager.h" #include "utils/ms_context.h" #include "utils/ms_utils.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/device_manager.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_adjust.h" #include "backend/optimizer/common/helper.h" @@ -36,6 +38,79 @@ namespace mindspore { namespace device { namespace ascend { namespace { +constexpr uint32_t kDeviceNumOfServer = 8; +constexpr uint32_t kDeviceNumThreshold = 1024; + +constexpr uint32_t kMaxStreamNum = 1024; +constexpr uint32_t kHcomSecondaryStreamNum = 3; + +constexpr uint32_t kMaxTaskNumPerStream = 1010; +constexpr uint32_t kMaxCommonNodeNumPerStream = 350; + +constexpr uint32_t kTaskNumPerHcomNode = 200; +constexpr uint32_t kTaskNumPerWorldHcomNode = 250; +constexpr uint32_t kTaskNumPerSameServerHcomNode = 125; +constexpr uint32_t kTaskNumPerHcomSendRecvNode = 15; + +bool IsSameServer(const std::vector &rank_ids) { + auto min_iter = min_element(rank_ids.begin(), rank_ids.end()); + uint32_t min = (min_iter != rank_ids.end()) ? *min_iter : 0; + auto max_iter = max_element(rank_ids.begin(), rank_ids.end()); + uint32_t max = (max_iter != rank_ids.end()) ? *max_iter : 0; + return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer)); +} + +string GetHcomGroup(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) { + MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute."; + } + + return AnfAlgo::GetNodeAttr(cnode, kAttrGroup); +} + +uint32_t GetHcomTaskNum(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) { + MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute."; + } + + if (parallel::g_device_manager == nullptr) { + MS_LOG(INFO) << "Device manager is nullptr."; + return kTaskNumPerHcomNode; + } + + auto node_name = AnfAlgo::GetCNodeName(cnode); + if (node_name == kHcomSendOpName || node_name == kReceiveOpName) { + return kTaskNumPerHcomSendRecvNode; + } + + MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); + auto device_num = parallel::ParallelContext::GetInstance()->device_num(); + auto group_name = AnfAlgo::GetNodeAttr(cnode, kAttrGroup); + auto group_info = parallel::g_device_manager->group_info(); + for (const auto &info : group_info) { + if (info.first != group_name) { + continue; + } + const auto &rank_ids = info.second; + if (IsSameServer(rank_ids)) { + return kTaskNumPerSameServerHcomNode; + } else if (rank_ids.size() == static_cast(device_num) && device_num >= kDeviceNumThreshold) { + return kTaskNumPerWorldHcomNode; + } else { + return kTaskNumPerHcomNode; + } + } + + // world group is not in group_info. + if (device_num >= kDeviceNumThreshold) { + return kTaskNumPerWorldHcomNode; + } else { + return kTaskNumPerHcomNode; + } +} + CNodePtr GetHcomAndOverflowMarker(const NotNull &graph_ptr, vector *hcom_nodes) { auto cnode_ptr_list = graph_ptr->execution_order(); CNodePtr overflow_marker = nullptr; @@ -90,9 +165,6 @@ StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, u } } // namespace -const uint32_t kHcomMaxTask = 4; -const uint32_t kCommonMaxTask = 350; - void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) { if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) { Reset(); @@ -110,6 +182,10 @@ void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) AdjustAtomicAddrCleanOrder(graph_ptr); GetNeedActiveStreams(graph_ptr); + + MS_LOG(INFO) << "Before check resource assign"; + graph_ptr->PrintGraphExecuteOrder(); + CheckResourceAssign(graph_ptr); MS_LOG(INFO) << "After finish stream assign"; #ifdef ENABLE_DUMP_IR @@ -478,15 +554,26 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull &gra AssignCommonStreamId(cur_cnode_ptr); } - MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num(); + + auto common_stream_num = resource_manager.get_cur_stream_num(); if (exit_hcom) { AssignHcom(graph_ptr); } + auto hcom_stream_num = resource_manager.get_cur_stream_num() - common_stream_num; if (exit_independent) { AssignIndependent(graph_ptr); } + auto independent_stream_num = resource_manager.get_cur_stream_num() - common_stream_num - hcom_stream_num; + auto total_stream_num = resource_manager.get_cur_stream_num() + hcom_stream_num * kHcomSecondaryStreamNum; + MS_LOG(INFO) << "Total stream number: " << total_stream_num << ", common stream number: " << common_stream_num + << ", hcom stream number: " << hcom_stream_num << "*" << kHcomSecondaryStreamNum + 1 + << ", independent stream number: " << independent_stream_num << "."; + + if (total_stream_num > kMaxStreamNum) { + MS_LOG(EXCEPTION) << "Total stream number " << total_stream_num << " exceeds the limit of " << kMaxStreamNum << "."; + } MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num(); } @@ -507,7 +594,7 @@ void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get()); common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1)); } else { - if (it->second < kCommonMaxTask) { + if (it->second < kMaxCommonNodeNumPerStream) { AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); it->second++; } else { @@ -529,10 +616,7 @@ void AscendStreamAssign::AssignHcom(const NotNull &graph_ptr) { } if (IsHcom(cur_cnode_ptr)) { - if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) { - MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode_ptr->DebugString() << " has no group attr"; - } - auto group_name = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrGroup); + auto group_name = GetHcomGroup(cur_cnode_ptr); auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get()); auto iter = group_graph_nodes_map.find(group_name); if (iter == group_graph_nodes_map.end()) { @@ -576,6 +660,8 @@ void AscendStreamAssign::AssignHcom(const NotNull &graph_ptr) { uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) { MS_EXCEPTION_IF_NULL(cur_cnode_ptr); AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto task_num = GetHcomTaskNum(cur_cnode_ptr); + uint32_t cur_hcom_stream_id; if (new_graph) { cur_hcom_stream_id = resource_manager.ApplyNewStream(); @@ -585,15 +671,15 @@ uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, b auto it = hcom_stream_map_.find(cur_hcom_stream_id); if (it == hcom_stream_map_.end()) { AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); - hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); + hcom_stream_map_.emplace(cur_hcom_stream_id, task_num); } else { - if (it->second < kHcomMaxTask) { + if (it->second <= kMaxTaskNumPerStream - task_num) { AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); - it->second++; + it->second += task_num; } else { cur_hcom_stream_id = resource_manager.ApplyNewStream(); AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); - hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); + hcom_stream_map_.emplace(cur_hcom_stream_id, task_num); } } return cur_hcom_stream_id; @@ -646,7 +732,7 @@ uint32_t AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get()); independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1)); } else { - if (it->second < kCommonMaxTask) { + if (it->second < kMaxCommonNodeNumPerStream) { AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); it->second++; } else { @@ -956,7 +1042,7 @@ void AscendStreamAssign::InsertStreamActiveForIndependent(const NotNull AscendStreamAssign::GetLastInputCnode(const NotNull &graph_ptr, const CNodePtr &cur_cnode_ptr) { auto cnode_ptr_list = graph_ptr->execution_order(); - auto group_name = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrGroup); + auto group_name = GetHcomGroup(cur_cnode_ptr); auto input_cnodes = GetInputKernels(cur_cnode_ptr); if (input_cnodes.empty()) { return {}; @@ -1256,7 +1342,7 @@ vector AscendStreamAssign::GetLastInputCnode(const NotNull(item.second.first, kAttrGroup); + auto cur_group = GetHcomGroup(item.second.first); if (cur_group == group_name) { continue; } else { @@ -1368,10 +1454,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull } uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); - if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) { - MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr"; - } - auto group_name = AnfAlgo::GetNodeAttr(cur_cnode, kAttrGroup); + auto group_name = GetHcomGroup(cur_cnode); MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name << "; stream id:" << cur_stream_id; if (group_name != group) { diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h index d62f900467..a5138e4e1f 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -184,7 +184,7 @@ class AscendStreamAssign { void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); void SetLoopSink(); - // function for memory resue + // function for memory reuse void GetStreamRelations(); void DFS(uint32_t start, std::vector *group); bool IsVecExist(const std::vector &group); @@ -202,10 +202,14 @@ class AscendStreamAssign { bool independent_stream_activated_{false}; bool hcom_stream_activated_{false}; bool loop_sink_{false}; - // key:stream id, value:task nums; + + // key:stream id, value:node number + std::map common_stream_map_{}; + // key:stream id, value:node number std::map independent_stream_map_{}; + // key:stream id, value:task number std::map hcom_stream_map_{}; - std::map common_stream_map_{}; + std::set processed_streams_{}; std::vector need_first_active_streams_{}; std::set independent_targets_;