|
|
|
@@ -1481,50 +1481,51 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> |
|
|
|
if (group_hcom_graph_map_.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
std::vector<string> groups; |
|
|
|
for (const auto &item : group_hcom_graph_map_) { |
|
|
|
groups.emplace_back(item.first); |
|
|
|
} |
|
|
|
for (const auto &group : groups) { |
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order(); |
|
|
|
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indices; |
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); i++) { |
|
|
|
auto cur_cnode = cnode_ptr_list[i]; |
|
|
|
if (!IsHcom(cur_cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); |
|
|
|
auto group_name = GetHcomGroup(cur_cnode); |
|
|
|
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name |
|
|
|
<< "; stream id:" << cur_stream_id; |
|
|
|
if (group_name != group) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (const auto &group_item : group_hcom_graph_map_) { |
|
|
|
auto group = group_item.first; |
|
|
|
for (const auto &graph_item : group_item.second) { |
|
|
|
auto graph_id = graph_item.first; |
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order(); |
|
|
|
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indices; |
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); i++) { |
|
|
|
auto cur_cnode = cnode_ptr_list[i]; |
|
|
|
if (!IsHcom(cur_cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (stream_indices.empty()) { |
|
|
|
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i})); |
|
|
|
} else { |
|
|
|
bool exit = false; |
|
|
|
for (auto &item : stream_indices) { |
|
|
|
if (item.first == cur_stream_id) { |
|
|
|
item.second.emplace_back(i); |
|
|
|
exit = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); |
|
|
|
auto group_name = GetHcomGroup(cur_cnode); |
|
|
|
auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode.get()); |
|
|
|
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name |
|
|
|
<< "; stream id:" << cur_stream_id; |
|
|
|
if (group_name != group || cur_graph_id != graph_id) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!exit) { |
|
|
|
|
|
|
|
if (stream_indices.empty()) { |
|
|
|
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i})); |
|
|
|
} else { |
|
|
|
bool exit = false; |
|
|
|
for (auto &item : stream_indices) { |
|
|
|
if (item.first == cur_stream_id) { |
|
|
|
item.second.emplace_back(i); |
|
|
|
exit = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!exit) { |
|
|
|
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i})); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (stream_indices.size() < 2) { |
|
|
|
MS_LOG(INFO) << "Group:" << group |
|
|
|
<< "; different stream hcom size is less than 2, no need insert event between them"; |
|
|
|
continue; |
|
|
|
if (stream_indices.size() < 2) { |
|
|
|
MS_LOG(INFO) << "Group:" << group |
|
|
|
<< "; different stream hcom size is less than 2, no need insert event between them"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
InsertEventBetweenHcom(graph_ptr, stream_indices); |
|
|
|
} |
|
|
|
InsertEventBetweenHcom(graph_ptr, stream_indices); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|