|
|
|
@@ -83,7 +83,7 @@ Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_ |
|
|
|
if (!is_unknown_shape) { |
|
|
|
group_node[group_id].push_back(node); |
|
|
|
parallel_group.insert(group_id); |
|
|
|
GELOGI("Find hccl node:%s, group_id=%d", op_desc->GetName().c_str(), group_id); |
|
|
|
GELOGD("Find group node:%s, group_id=%d", node->GetName().c_str(), group_id); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -116,7 +116,8 @@ Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_ |
|
|
|
cur_node = node_vec[i]; |
|
|
|
auto tmp_pre_node = pre_node; |
|
|
|
auto tmp_cur_node = cur_node; |
|
|
|
GELOGI("original we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); |
|
|
|
GELOGD("original add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), |
|
|
|
cur_node->GetName().c_str()); |
|
|
|
ReplaceSwitchAndMerge(tmp_pre_node, tmp_cur_node, node_2_switch_merge); |
|
|
|
pre_node = cur_node; |
|
|
|
} |
|
|
|
@@ -127,13 +128,13 @@ Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_ |
|
|
|
|
|
|
|
void ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { |
|
|
|
if (pre_node == cur_node) { |
|
|
|
GELOGI("--- pr_node == cur_node"); |
|
|
|
return; |
|
|
|
} |
|
|
|
const auto &in_node = cur_node->GetInAllNodes(); |
|
|
|
for (const auto &node : in_node) { |
|
|
|
if (pre_node == node) { |
|
|
|
GELOGI("--- pr_node and cur_node have linked"); |
|
|
|
GELOGD("node:%s and node:%s has linked", pre_node->GetName().c_str(), |
|
|
|
cur_node->GetName().c_str()); |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -211,7 +212,7 @@ Status ParallelGroupPass::ProcessSwitch(ComputeGraphPtr graph, |
|
|
|
auto &tmp = it->second; |
|
|
|
auto &switch_vec = tmp.first; |
|
|
|
const auto &merge_node = tmp.second; |
|
|
|
GELOGI(" --- hccl node: %s, switch node %s, merge node :%s.", |
|
|
|
GELOGD("Find group node: %s in switch node %s and merge node :%s.", |
|
|
|
group_node->GetName().c_str(), node->GetName().c_str(), merge_node->GetName().c_str()); |
|
|
|
if (merge_node != merge_vec.back()) { |
|
|
|
GELOGE(GRAPH_FAILED, "error: has two merge node: %s and %s.", |
|
|
|
@@ -263,15 +264,15 @@ void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node, |
|
|
|
pre_node = pre_itr->second.second; |
|
|
|
for (const auto &switch_node : cur_itr->second.first) { |
|
|
|
AddCtrlEdge(pre_node, switch_node); |
|
|
|
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str()); |
|
|
|
GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), |
|
|
|
switch_node->GetName().c_str()); |
|
|
|
} |
|
|
|
} else { |
|
|
|
GELOGI("--- no need add ctrl edge"); |
|
|
|
} |
|
|
|
} else { |
|
|
|
pre_node = pre_itr->second.second; |
|
|
|
AddCtrlEdge(pre_node, cur_node); |
|
|
|
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); |
|
|
|
GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), |
|
|
|
cur_node->GetName().c_str()); |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (cur_itr != node_2_switch_merge.end()) { |
|
|
|
@@ -281,20 +282,24 @@ void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node, |
|
|
|
if (pre_id > switch_id) { // special handle for merge and group node |
|
|
|
auto merge_node = cur_itr->second.second; |
|
|
|
AddCtrlEdge(merge_node, pre_node); |
|
|
|
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", merge_node->GetName().c_str(), pre_node->GetName().c_str()); |
|
|
|
GELOGD("finally add ctrl anchor for node:%s-->node:%s", merge_node->GetName().c_str(), |
|
|
|
pre_node->GetName().c_str()); |
|
|
|
} else { |
|
|
|
AddCtrlEdge(pre_node, switch_node); |
|
|
|
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str()); |
|
|
|
GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), |
|
|
|
switch_node->GetName().c_str()); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
AddCtrlEdge(pre_node, cur_node); |
|
|
|
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); |
|
|
|
GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), |
|
|
|
cur_node->GetName().c_str()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) { |
|
|
|
bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, |
|
|
|
const std::set<NodePtr> &switch_set2) { |
|
|
|
for (const auto &node1 : switch_set1) { |
|
|
|
for (const auto &node2 : switch_set2) { |
|
|
|
if (node1 == node2) { |
|
|
|
|