|
|
|
@@ -196,7 +196,7 @@ void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { |
|
|
|
|
|
|
|
void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order(); |
|
|
|
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map; |
|
|
|
std::map<std::string, std::map<uint32_t, std::vector<CNodePtr>>> group_graph_nodes_map; |
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { |
|
|
|
CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; |
|
|
|
// node has been assigned stream before |
|
|
|
@@ -205,27 +205,52 @@ void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
} |
|
|
|
|
|
|
|
if (IsHcom(cur_cnode_ptr)) { |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) { |
|
|
|
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode_ptr->DebugString() << " has no group attr"; |
|
|
|
} |
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup); |
|
|
|
auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get()); |
|
|
|
auto it = graph_nodes_map.find(hcom_graph_id); |
|
|
|
if (it == graph_nodes_map.end()) { |
|
|
|
auto iter = group_graph_nodes_map.find(group_name); |
|
|
|
if (iter == group_graph_nodes_map.end()) { |
|
|
|
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map; |
|
|
|
graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr}; |
|
|
|
group_graph_nodes_map[group_name] = graph_nodes_map; |
|
|
|
} else { |
|
|
|
it->second.emplace_back(cur_cnode_ptr); |
|
|
|
auto &graph_nodes_map = iter->second; |
|
|
|
auto it = graph_nodes_map.find(hcom_graph_id); |
|
|
|
if (it == graph_nodes_map.end()) { |
|
|
|
graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr}; |
|
|
|
} else { |
|
|
|
it->second.emplace_back(cur_cnode_ptr); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "hcom diff graph id size:" << graph_nodes_map.size(); |
|
|
|
for (const auto &item : graph_nodes_map) { |
|
|
|
bool new_graph = true; |
|
|
|
auto graph_id = item.first; |
|
|
|
hcom_graph_map_[graph_id] = {}; |
|
|
|
for (const auto &hcom_node_ptr : item.second) { |
|
|
|
auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph); |
|
|
|
hcom_graph_map_[graph_id].emplace(assigned_stream_id); |
|
|
|
new_graph = false; |
|
|
|
|
|
|
|
MS_LOG(INFO) << "hcom diff group size:" << group_graph_nodes_map.size(); |
|
|
|
for (const auto &item : group_graph_nodes_map) { |
|
|
|
MS_LOG_INFO << "group id:" << item.first << "; diff graph id size:" << item.second.size(); |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &diff_group : group_graph_nodes_map) { |
|
|
|
// group id: |
|
|
|
std::map<uint32_t, std::set<uint32_t>> hcom_graph_map; |
|
|
|
for (const auto &item : diff_group.second) { |
|
|
|
bool new_graph = true; |
|
|
|
auto graph_id = item.first; |
|
|
|
hcom_graph_map[graph_id] = {}; |
|
|
|
for (const auto &hcom_node_ptr : item.second) { |
|
|
|
auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph); |
|
|
|
hcom_graph_map[graph_id].emplace(assigned_stream_id); |
|
|
|
new_graph = false; |
|
|
|
} |
|
|
|
} |
|
|
|
group_hcom_graph_map_[diff_group.first] = hcom_graph_map; |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &item : group_hcom_graph_map_) { |
|
|
|
MS_LOG_INFO << "group id:" << item.first << "; hcom stream nums:" << item.second.size(); |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "hcom stream nums : " << hcom_stream_map_.size(); |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) { |
|
|
|
@@ -337,7 +362,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
if (hcom_graph_map_.empty() && independent_graph_map_.empty()) { |
|
|
|
if (group_hcom_graph_map_.empty() && independent_graph_map_.empty()) { |
|
|
|
MS_LOG(INFO) << "Hcom and independent is empty"; |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -347,19 +372,32 @@ void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull<KernelGraph |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Hcom grpah map size:" << hcom_graph_map_.size(); |
|
|
|
std::map<uint32_t, std::set<uint32_t>> other_graph; |
|
|
|
for (const auto &item : hcom_graph_map_) { |
|
|
|
MS_LOG(INFO) << "Graph id:" << item.first; |
|
|
|
if (item.first == root_graph_id) { |
|
|
|
if (loop_sink_) { |
|
|
|
ActiveRootGraphHcom(graph_ptr, item.second); |
|
|
|
std::set<uint32_t> hcom_streams; |
|
|
|
for (const auto &graph_nodes : group_hcom_graph_map_) { |
|
|
|
for (const auto &item : graph_nodes.second) { |
|
|
|
MS_LOG(INFO) << "Graph id:" << item.first; |
|
|
|
if (item.first == root_graph_id) { |
|
|
|
if (loop_sink_) { |
|
|
|
hcom_streams.insert(item.second.begin(), item.second.end()); |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto it = other_graph.find(item.first); |
|
|
|
if (it == other_graph.end()) { |
|
|
|
other_graph[item.first] = item.second; |
|
|
|
} else { |
|
|
|
for (const auto &stream : item.second) { |
|
|
|
it->second.emplace(stream); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
other_graph[item.first] = item.second; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (!hcom_streams.empty()) { |
|
|
|
ActiveRootGraphHcom(graph_ptr, hcom_streams); |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Independent graph map size:" << independent_graph_map_.size(); |
|
|
|
for (const auto &item : independent_graph_map_) { |
|
|
|
MS_LOG(DEBUG) << "Graph id:" << item.first; |
|
|
|
@@ -505,7 +543,6 @@ void AscendStreamAssign::ActiveRootGraphIndependent(const NotNull<KernelGraphPtr |
|
|
|
independent_stream_activated_ = true; |
|
|
|
graph_ptr->set_execution_order(update_cnode_list); |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
MS_LOG(INFO) << "Start"; |
|
|
|
GetProcessedStream(graph_ptr); |
|
|
|
@@ -733,7 +770,7 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { |
|
|
|
// Runs the event-insertion steps for hcom parallelism on graph_ptr, one after another: |
// logs "Start", calls each helper below with the same graph in the listed order, then logs "End". |
// NOTE(review): this text looks like a rendered diff (7 old / 7 new lines in the hunk), so |
// InsertEventHcomDependCommon and InsertEventHcomDependCommonBak are presumably the old and new |
// form of a single call rather than two consecutive calls - confirm against the real file. |
void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
MS_LOG(INFO) << "Start"; |
|
|
|
InsertEventCommonDependHcom(graph_ptr); |
|
|
|
InsertEventHcomDependCommon(graph_ptr); |
|
|
|
InsertEventHcomDependCommonBak(graph_ptr); |
|
|
|
InsertEventHcomDependHcom(graph_ptr); |
|
|
|
MS_LOG(INFO) << "End"; |
|
|
|
} |
|
|
|
@@ -777,36 +814,6 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt |
|
|
|
MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num(); |
|
|
|
} |
|
|
|
|
|
|
|
// Returns the input of cur_cnode_ptr that sits latest in the execution order of graph_ptr.
//
// Only inputs from index 1 onwards that are CNodes are considered. A nop node is replaced by
// following its input at index 1 until a non-nop node is reached. Each candidate is searched
// from the furthest position found so far, so the iterator only ever moves forward.
//
// Raises through MS_LOG(EXCEPTION) when none of the inputs can be located in the execution order.
CNodePtr AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
                                               const CNodePtr &cur_cnode_ptr) {
  const auto &cnode_ptr_list = graph_ptr->execution_order();
  const auto &inputs = cur_cnode_ptr->inputs();
  auto it_pos = cnode_ptr_list.begin();
  // Explicit flag instead of comparing against inputs[1]: that comparison read out of bounds for
  // nodes without inputs and misfired when an input other than the first one was at position 0.
  bool found = false;
  for (size_t i = 1; i < inputs.size(); i++) {
    if (!inputs[i]->isa<CNode>()) {
      continue;
    }
    auto cnode = inputs[i]->cast<CNodePtr>();
    // Skip over nop nodes; stop when the chain ends in a non-CNode or a nop node without inputs.
    while (cnode != nullptr && opt::IsNopNode(cnode)) {
      const auto &nop_inputs = cnode->inputs();
      if (nop_inputs.size() < 2) {
        break;
      }
      cnode = nop_inputs[1]->cast<CNodePtr>();
    }
    if (cnode == nullptr) {
      continue;
    }

    auto it = std::find(it_pos, cnode_ptr_list.end(), cnode);
    if (it != cnode_ptr_list.end()) {
      it_pos = it;
      found = true;
    }
  }

  if (!found) {
    MS_LOG(EXCEPTION) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << " was not found";
  }

  // found == true guarantees the execution order is non-empty, so it_pos is dereferenceable here.
  MS_LOG(INFO) << "The last input of node:" << cur_cnode_ptr->DebugString() << " is:" << (*it_pos)->fullname_with_scope()
               << "; name:" << (*it_pos)->DebugString();
  return *it_pos;
}
|
|
|
|
|
|
|
// after memory reuse is correct, use this function |
|
|
|
void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); |
|
|
|
@@ -830,7 +837,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap |
|
|
|
auto last_input_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr); |
|
|
|
auto it = std::find(cnodes.begin(), cnodes.end(), last_input_cnode); |
|
|
|
if (it == cnodes.end()) { |
|
|
|
MS_LOG(ERROR) << "hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) |
|
|
|
MS_LOG(ERROR) << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) |
|
|
|
<< "get last input:" << AnfAlgo::GetCNodeName(last_input_cnode) << "; but last input not in cnodes"; |
|
|
|
} else { |
|
|
|
uint32_t cur_event_id = resource_manager.ApplyNewEvent(); |
|
|
|
@@ -848,6 +855,58 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap |
|
|
|
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); |
|
|
|
} |
|
|
|
|
|
|
|
// Returns the input kernel of cur_cnode_ptr that sits latest in the execution order of graph_ptr.
//
// The candidates are the kernels returned by GetInputKernels(cur_cnode_ptr). Each candidate is
// searched from the furthest position found so far, so the iterator only ever moves forward.
//
// Returns nullptr when the node has no input kernels or the execution order is empty.
// When none of the input kernels is present in the execution order an error is logged and the
// first node of the execution order is returned, as before.
// NOTE(review): returning an unrelated node in that case looks unintended - confirm with callers.
CNodePtr AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
                                               const CNodePtr &cur_cnode_ptr) {
  const auto input_cnodes = GetInputKernels(cur_cnode_ptr);
  if (input_cnodes.empty()) {
    return nullptr;
  }

  // Fetched after the early return and bound by reference: the list is only read here.
  const auto &cnode_ptr_list = graph_ptr->execution_order();
  if (cnode_ptr_list.empty()) {
    // begin() == end() must not be dereferenced below.
    MS_LOG(ERROR) << "The execution order is empty, input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
                  << " was not found";
    return nullptr;
  }

  auto it_pos = cnode_ptr_list.begin();
  // Explicit flag instead of comparing against input_cnodes.front(): that comparison reported an
  // error whenever an input other than the first one happened to be the first node in the list.
  bool found = false;
  for (const auto &cnode : input_cnodes) {
    auto it = std::find(it_pos, cnode_ptr_list.end(), cnode);
    if (it != cnode_ptr_list.end()) {
      it_pos = it;
      found = true;
    }
  }
  if (!found) {
    MS_LOG(ERROR) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << " was not found";
  }

  return *it_pos;
}
|
|
|
|
|
|
|
// Collects the non-nop CNodes that feed `node`, looking through nop nodes.
//
// Every input of `node` from index 1 onwards is resolved with AnfAlgo::VisitKernel(input, 0).
// A resolved node that is a CNode and not a nop node is appended to the result. A nop node is
// expanded breadth-first instead: its own inputs (again from index 1) are resolved the same way,
// and nested nop nodes are queued, until only non-nop nodes remain.
//
// Returns the collected CNodes in visiting order; duplicates are not removed and resolved inputs
// that are neither nop nodes nor CNodes are ignored.
vector<CNodePtr> AscendStreamAssign::GetInputKernels(const CNodePtr &node) {
  vector<CNodePtr> input_cnodes;
  queue<CNodePtr> nop_nodes;
  const auto &inputs = node->inputs();
  for (size_t i = 1; i < inputs.size(); i++) {
    // Named real_input so that it no longer shadows the `node` parameter.
    const auto real_input = AnfAlgo::VisitKernel(inputs[i], 0).first;
    if (!opt::IsNopNode(real_input)) {
      if (real_input->isa<CNode>()) {
        input_cnodes.emplace_back(real_input->cast<CNodePtr>());
      }
      continue;
    }

    // The queue is empty at this point and fully drained before the next direct input is
    // handled, so the kernels behind this nop node stay grouped at this position of the result.
    nop_nodes.push(real_input->cast<CNodePtr>());
    while (!nop_nodes.empty()) {
      // Copy the pointer: pop() invalidates the reference returned by front().
      const auto cur_nop = nop_nodes.front();
      nop_nodes.pop();
      const auto &nop_inputs = cur_nop->inputs();
      for (size_t j = 1; j < nop_inputs.size(); j++) {
        const auto nop_real_input = AnfAlgo::VisitKernel(nop_inputs[j], 0).first;
        if (opt::IsNopNode(nop_real_input)) {
          nop_nodes.push(nop_real_input->cast<CNodePtr>());
        } else if (nop_real_input->isa<CNode>()) {
          input_cnodes.emplace_back(nop_real_input->cast<CNodePtr>());
        }
      }
    }
  }
  return input_cnodes;
}
|
|
|
|
|
|
|
void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); |
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order(); |
|
|
|
@@ -896,40 +955,70 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt |
|
|
|
void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); |
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order(); |
|
|
|
uint32_t first_hcom_stream = kInvalidStreamId; |
|
|
|
uint32_t last_hcom_stream = kInvalidStreamId; |
|
|
|
// key: stream id, value:hcom index |
|
|
|
std::map<uint32_t, vector<size_t>> hcom_index; |
|
|
|
// key:group id, key: stream id, value:hcom index |
|
|
|
std::map<std::string, std::map<uint32_t, vector<size_t>>> group_hcom_index; |
|
|
|
std::map<std::string, uint32_t> group_first_hcom_stream; |
|
|
|
std::map<std::string, uint32_t> group_last_hcom_stream; |
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); i++) { |
|
|
|
auto cur_cnode = cnode_ptr_list[i]; |
|
|
|
if (!IsHcom(cur_cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); |
|
|
|
auto it = hcom_index.find(cur_stream_id); |
|
|
|
if (it != hcom_index.end()) { |
|
|
|
hcom_index[cur_stream_id].emplace_back(i); |
|
|
|
} else { |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) { |
|
|
|
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr"; |
|
|
|
} |
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup); |
|
|
|
auto iter = group_hcom_index.find(group_name); |
|
|
|
if (iter == group_hcom_index.end()) { |
|
|
|
std::map<uint32_t, vector<size_t>> hcom_index; |
|
|
|
hcom_index[cur_stream_id] = {i}; |
|
|
|
group_hcom_index[group_name] = hcom_index; |
|
|
|
} else { |
|
|
|
auto &hcom_index = iter->second; |
|
|
|
auto it = hcom_index.find(cur_stream_id); |
|
|
|
if (it == hcom_index.end()) { |
|
|
|
hcom_index[cur_stream_id] = {i}; |
|
|
|
} else { |
|
|
|
it->second.emplace_back(i); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// record first hcom stream id |
|
|
|
if (first_hcom_stream == kInvalidStreamId) { |
|
|
|
first_hcom_stream = cur_stream_id; |
|
|
|
auto it = group_first_hcom_stream.find(group_name); |
|
|
|
if (it == group_first_hcom_stream.end()) { |
|
|
|
group_first_hcom_stream[group_name] = cur_stream_id; |
|
|
|
} |
|
|
|
|
|
|
|
// record last hcom stream id |
|
|
|
if (cur_stream_id != last_hcom_stream) { |
|
|
|
last_hcom_stream = cur_stream_id; |
|
|
|
it = group_last_hcom_stream.find(group_name); |
|
|
|
if (it != group_last_hcom_stream.end()) { |
|
|
|
it->second = cur_stream_id; |
|
|
|
} else { |
|
|
|
group_last_hcom_stream[group_name] = cur_stream_id; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (hcom_index.size() < 2) { |
|
|
|
MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; |
|
|
|
return; |
|
|
|
for (const auto &hcom_index : group_hcom_index) { |
|
|
|
if (hcom_index.second.size() < 2) { |
|
|
|
MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; |
|
|
|
return; |
|
|
|
} |
|
|
|
auto group_name = hcom_index.first; |
|
|
|
auto it = group_first_hcom_stream.find(group_name); |
|
|
|
if (it == group_first_hcom_stream.end()) { |
|
|
|
MS_LOG_EXCEPTION << "Can't find first hcom stream, hcom group id:" << group_name; |
|
|
|
} |
|
|
|
auto first_hcom_stream = it->second; |
|
|
|
|
|
|
|
it = group_last_hcom_stream.find(group_name); |
|
|
|
if (it == group_last_hcom_stream.end()) { |
|
|
|
MS_LOG_EXCEPTION << "Can't find last hcom stream, hcom group id:" << group_name; |
|
|
|
} |
|
|
|
auto last_hcom_stream = it->second; |
|
|
|
InsertEventBetweenHcom(graph_ptr, hcom_index.second, first_hcom_stream, last_hcom_stream); |
|
|
|
MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); |
|
|
|
} |
|
|
|
InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream); |
|
|
|
MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, |
|
|
|
@@ -1199,9 +1288,12 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra |
|
|
|
|
|
|
|
// 3)hcom stream:if has not been activate, push to need active vector |
|
|
|
if (!hcom_stream_activated_) { |
|
|
|
auto it = hcom_graph_map_.find(root_graph_id); |
|
|
|
if (it != hcom_graph_map_.end()) { |
|
|
|
std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_)); |
|
|
|
for (const auto &item : group_hcom_graph_map_) { |
|
|
|
auto &hcom_graph_map = item.second; |
|
|
|
auto it = hcom_graph_map.find(root_graph_id); |
|
|
|
if (it != hcom_graph_map.end()) { |
|
|
|
std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1434,7 +1526,7 @@ void AscendStreamAssign::Reset() { |
|
|
|
event_map_.clear(); |
|
|
|
independent_targets_.clear(); |
|
|
|
independent_graph_map_.clear(); |
|
|
|
hcom_graph_map_.clear(); |
|
|
|
group_hcom_graph_map_.clear(); |
|
|
|
middle_active_streams_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
|