diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index bfc0809a11..9959a65966 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -104,7 +104,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull continue; } - auto res = FindTargetOp(begin, end, cur_independent); + auto res = FindTargetOp(begin, end, cur_independent, false); if (res != end) { flag = true; exe_orders.emplace_back(cur_independent); @@ -247,10 +247,6 @@ void AscendStreamAssign::AssignHcom(const NotNull &graph_ptr) { } group_hcom_graph_map_[diff_group.first] = hcom_graph_map; } - - for (const auto &item : group_hcom_graph_map_) { - MS_LOG_INFO << "group id:" << item.first << "; hcom stream nums:" << item.second.size(); - } } uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) { @@ -787,7 +783,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNullfullname_with_scope() << ", can't find target for insert recv op, no insert send/recv"; @@ -795,11 +791,6 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr); + if (inputs_cnode.empty()) { + MS_LOG(WARNING) << "Hcom op:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << " can't find inputs nodes"; + continue; + } + + MS_LOG(INFO) << "Current hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) + << "; inputs cnode size:" << inputs_cnode.size(); + + for (size_t j = 0; j < inputs_cnode.size(); j++) { + auto &cur_input = inputs_cnode.at(j); + MS_LOG(INFO) << "The index:" << j << " input, name:" << AnfAlgo::GetCNodeName(cur_input); uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - auto last_stream_id = AnfAlgo::GetStreamId(last_input_cnode); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, last_stream_id); + auto pre_stream_id = AnfAlgo::GetStreamId(cur_input); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id); + auto it = std::find(cnodes.begin(), cnodes.end(), cur_input); + if (it == cnodes.end()) { + MS_LOG_EXCEPTION << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) + << " can't find input node:" << AnfAlgo::GetCNodeName(cur_input); + } cnodes.insert(it + 1, send); uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id); @@ -855,26 +857,56 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull &graph_ptr, - const CNodePtr &cur_cnode_ptr) { +vector AscendStreamAssign::GetLastInputCnode(const NotNull &graph_ptr, + const CNodePtr &cur_cnode_ptr) { auto cnode_ptr_list = graph_ptr->execution_order(); + auto group_name = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrGroup); auto input_cnodes = GetInputKernels(cur_cnode_ptr); if (input_cnodes.empty()) { - return nullptr; + return {}; + } + // record max index node for each stream + std::map> result; + for (size_t i = 0; i < input_cnodes.size(); i++) { + auto &cur_input = input_cnodes.at(i); + auto stream_id = AnfAlgo::GetStreamId(cur_input); + auto cur_index = GetIndexByKey(graph_ptr, cur_input.get()); + if (cur_index == UINT32_MAX) { + MS_LOG_EXCEPTION << "The input node:" << AnfAlgo::GetCNodeName(cur_input) << " is not found in graph"; + } + auto it = result.find(stream_id); + if (it == result.end()) { + result[stream_id] = std::make_pair(cur_input, cur_index); + } else { + auto max_index = it->second.second; + if (cur_index > max_index) { + result[stream_id] = std::make_pair(cur_input, cur_index); + } + } } - auto it_pos = cnode_ptr_list.begin(); - for (auto &cnode : input_cnodes) { - auto it = std::find(it_pos, cnode_ptr_list.end(), cnode); - if (it != cnode_ptr_list.end()) { - it_pos = it; + vector final_inputs; + uint32_t max = 0; + CNodePtr max_common_cnode = nullptr; + for (const auto &item : result) { + if (IsHcom(item.second.first)) { + auto cur_group = AnfAlgo::GetNodeAttr(item.second.first, kAttrGroup); + if (cur_group == group_name) { + continue; + } else { + final_inputs.emplace_back(item.second.first); + } + } else { + if (item.second.second > max) { + max_common_cnode = item.second.first; + } } } - if (it_pos == cnode_ptr_list.begin() && *it_pos != input_cnodes.front()) { - MS_LOG(ERROR) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found"; - } - return *it_pos; + if (max_common_cnode != nullptr) { + final_inputs.emplace_back(max_common_cnode); + } + return final_inputs; } vector AscendStreamAssign::GetInputKernels(const CNodePtr &node) { @@ -956,9 +988,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); // key:group id, key: stream id, value:hcom index - std::map>> group_hcom_index; - std::map group_first_hcom_stream; - std::map group_last_hcom_stream; + std::map>>> group_hcom_index; for (size_t i = 0; i < cnode_ptr_list.size(); i++) { auto cur_cnode = cnode_ptr_list[i]; if (!IsHcom(cur_cnode)) { @@ -969,67 +999,60 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr"; } auto group_name = AnfAlgo::GetNodeAttr(cur_cnode, kAttrGroup); + MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name + << "; stream id:" << cur_stream_id; auto iter = group_hcom_index.find(group_name); if (iter == group_hcom_index.end()) { - std::map> hcom_index; - hcom_index[cur_stream_id] = {i}; + std::vector>> hcom_index; + hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector{i})); group_hcom_index[group_name] = hcom_index; } else { auto &hcom_index = iter->second; - auto it = hcom_index.find(cur_stream_id); - if (it == hcom_index.end()) { - hcom_index[cur_stream_id] = {i}; - } else { - it->second.emplace_back(i); + bool exit = false; + for (auto &item : hcom_index) { + if (item.first == cur_stream_id) { + item.second.emplace_back(i); + exit = true; + break; + } + } + if (!exit) { + hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector{i})); } } + } - // record first hcom stream id - auto it = group_first_hcom_stream.find(group_name); - if (it == group_first_hcom_stream.end()) { - group_first_hcom_stream[group_name] = cur_stream_id; - } - - // record last hcom stream id - it = group_last_hcom_stream.find(group_name); - if (it != group_last_hcom_stream.end()) { - it->second = cur_stream_id; - } else { - group_last_hcom_stream[group_name] = cur_stream_id; + for (const auto &hcom_index : group_hcom_index) { + MS_LOG(DEBUG) << "Group:" << hcom_index.first; + for (const auto &item : hcom_index.second) { + MS_LOG(DEBUG) << "stream id:" << item.first; + for (const auto &index : item.second) { + MS_LOG(DEBUG) << "hcom index:" << index; + } } } for (const auto &hcom_index : group_hcom_index) { if (hcom_index.second.size() < 2) { - MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; - return; - } - auto group_name = hcom_index.first; - auto it = group_first_hcom_stream.find(group_name); - if (it == group_first_hcom_stream.end()) { - MS_LOG_EXCEPTION << "Can't find first hcom stream, hcom group id:" << group_name; - } - auto first_hcom_stream = it->second; - - it = group_last_hcom_stream.find(group_name); - if (it == group_last_hcom_stream.end()) { - MS_LOG_EXCEPTION << "Can't find last hcom stream, hcom group id:" << group_name; + MS_LOG(INFO) << "Group:" << hcom_index.first + << "; different stream hcom size is less than 2, no need insert event between them"; + continue; } - auto last_hcom_stream = it->second; - InsertEventBetweenHcom(graph_ptr, hcom_index.second, first_hcom_stream, last_hcom_stream); + InsertEventBetweenHcom(graph_ptr, hcom_index.second); MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); } } void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &graph_ptr, - const map> &hcom_index, - uint32_t first_hcom_stream, uint32_t last_hcom_stream) { + const std::vector>> &hcom_index) { vector orders; AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - size_t first_stream_last_index = hcom_index.at(first_hcom_stream).back(); - size_t last_stream_first_index = hcom_index.at(last_hcom_stream).front(); + size_t first_stream_last_index = hcom_index[0].second.back(); + size_t last_stream_first_index = hcom_index.back().second.front(); + MS_LOG(INFO) << "First stream last index:" << first_stream_last_index + << "; last stream first index:" << last_stream_first_index; std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders)); for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) { auto cur_cnode = cnode_ptr_list[i]; @@ -1049,7 +1072,17 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &g orders.emplace_back(recv); orders.emplace_back(cur_cnode); } else { - auto cur_stream_hcom_size = hcom_index.at(cur_hcom_stream_id).size(); + size_t cur_stream_hcom_size = UINT32_MAX; + size_t first_index = UINT32_MAX; + size_t last_index = UINT32_MAX; + for (const auto &item : hcom_index) { + if (item.first == cur_hcom_stream_id) { + cur_stream_hcom_size = item.second.size(); + first_index = item.second.front(); + last_index = item.second.back(); + } + } + if (cur_stream_hcom_size == 1) { auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); orders.emplace_back(recv); @@ -1059,12 +1092,12 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &g orders.emplace_back(send); } else { // current stream, first hcom:add recv op - if (i == hcom_index.at(cur_hcom_stream_id).front()) { + if (i == first_index) { auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); orders.emplace_back(recv); cur_event_id = resource_manager.ApplyNewEvent(); orders.emplace_back(cur_cnode); - } else if (i == hcom_index.at(cur_hcom_stream_id).back()) { + } else if (i == last_index) { // current stream, last hcom:add send op orders.emplace_back(cur_cnode); auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); @@ -1080,19 +1113,19 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &g graph_ptr->set_execution_order(orders); } -bool AscendStreamAssign::IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, - size_t index) { +bool AscendStreamAssign::IsSatisfiedHcom(const std::vector>> &hcom_index, + const CNodePtr &node_ptr, size_t index) { MS_EXCEPTION_IF_NULL(node_ptr); auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr); - auto it = hcom_index.find(cur_hcom_stream_id); - if (it == hcom_index.end()) { - return false; - } - auto iter = std::find(hcom_index.at(cur_hcom_stream_id).begin(), hcom_index.at(cur_hcom_stream_id).end(), index); - if (iter == hcom_index.at(cur_hcom_stream_id).end()) { - return false; + for (const auto &item : hcom_index) { + if (item.first == cur_hcom_stream_id) { + auto it = std::find(item.second.begin(), item.second.end(), index); + if (it != item.second.end()) { + return true; + } + } } - return true; + return false; } // section6 @@ -1110,7 +1143,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNullfullname_with_scope() << "] can't find target for insert recv op, no insert send/recv"; @@ -1441,7 +1474,8 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull } vector::iterator AscendStreamAssign::FindTargetOp(vector::iterator begin, - vector::iterator end, const CNodePtr &node) { + vector::iterator end, const CNodePtr &node, + bool exclude_hcom) { while (begin != end) { auto inputs = (*begin)->inputs(); for (size_t i = 1; i < inputs.size(); i++) { @@ -1451,16 +1485,22 @@ vector::iterator AscendStreamAssign::FindTargetOp(vector::it auto new_inputs = cnode->inputs(); for (size_t j = 1; j < new_inputs.size(); j++) { auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0); + // find target node except hcom op. insert event for hcom in:InsertEventHcomDependCommonBak function + // only insert one time if (node == new_real_input.first) { - MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]"; - return begin; + if (!(exclude_hcom && IsHcom(*begin))) { + MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]"; + return begin; + } } } } else { auto real_input = AnfAlgo::VisitKernel(input, 0); if (node == real_input.first) { - MS_LOG(DEBUG) << "Find target op[" << (*begin)->DebugString() << "]"; - return begin; + if (!(exclude_hcom && IsHcom(*begin))) { + MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]"; + return begin; + } } } } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h index c8f6979c34..b0bb6e6a54 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "runtime/base.h" #include "runtime/rt_model.h" #include "runtime/stream.h" @@ -149,12 +150,13 @@ class AscendStreamAssign { void InsertEventHcomDependCommon(const NotNull &graph_ptr); void InsertEventHcomDependCommonBak(const NotNull &graph_ptr); void InsertEventHcomDependHcom(const NotNull &graph_ptr); - void InsertEventBetweenHcom(const NotNull &graph_ptr, const map> &hcom_index, - uint32_t first_hcom_stream, uint32_t last_hcom_stream); + void InsertEventBetweenHcom(const NotNull &graph_ptr, + const std::vector>> &hcom_index); void AdjustAtomicAddrCleanOrder(const NotNull &graph_ptr); - CNodePtr GetLastInputCnode(const NotNull &graph_ptr, const CNodePtr &cur_cnode_ptr); - bool IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, size_t index); + vector GetLastInputCnode(const NotNull &graph_ptr, const CNodePtr &cur_cnode_ptr); + bool IsSatisfiedHcom(const std::vector>> &hcom_index, const CNodePtr &node_ptr, + size_t index); void GetProcessedStream(const NotNull &graph_ptr); void GetNeedActiveStreams(const NotNull &graph_ptr); @@ -169,7 +171,7 @@ class AscendStreamAssign { bool IsIndependentNode(const CNodePtr &node_ptr); bool IsProcessedStream(uint32_t stream_id); vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, - const CNodePtr &node); + const CNodePtr &node, bool exclude_hcom); void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); void SetLoopSink(); @@ -199,6 +201,7 @@ class AscendStreamAssign { std::vector need_first_active_streams_{}; std::set independent_targets_; + // key:group name, value:key1:graph id, value1:stream id std::map>> group_hcom_graph_map_; // key:graph id, value:stream set std::map> independent_graph_map_;