|
|
|
@@ -24,6 +24,7 @@ |
|
|
|
#include <set> |
|
|
|
#include <stack> |
|
|
|
#include <string> |
|
|
|
#include <tuple> |
|
|
|
#include <vector> |
|
|
|
#include "base/core_ops.h" |
|
|
|
#include "ir/tensor.h" |
|
|
|
@@ -304,27 +305,52 @@ void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNo |
|
|
|
user_cnode->set_input(index, depend_cnode); |
|
|
|
} |
|
|
|
|
|
|
|
void AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &pre_node, |
|
|
|
const AnfNodePtr &post_node, const FuncGraphManagerPtr &mng) { |
|
|
|
// Collect use dependencies firstly. |
|
|
|
auto post_users = mng->node_users()[post_node]; |
|
|
|
|
|
|
|
AnfNodePtr AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, |
|
|
|
const AnfNodePtr &behind_node, const AnfNodePtr &patron_node) { |
|
|
|
// Create control depend, first input is composite op, second is user |
|
|
|
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), pre_node, post_node}; |
|
|
|
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node}; |
|
|
|
auto control_depend_cnode = main_graph->NewCNode(cd_inputs); |
|
|
|
main_graph->AddNode(control_depend_cnode); |
|
|
|
|
|
|
|
// Create depend node to hold new control depend node. |
|
|
|
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), post_node, control_depend_cnode}; |
|
|
|
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), patron_node, control_depend_cnode}; |
|
|
|
auto depend_cnode = main_graph->NewCNode(d_inputs); |
|
|
|
depend_cnode->set_abstract(post_node->abstract()); |
|
|
|
depend_cnode->set_abstract(patron_node->abstract()); |
|
|
|
main_graph->AddNode(depend_cnode); |
|
|
|
|
|
|
|
for (const auto &[user_node, index] : post_users) { |
|
|
|
auto user_cnode = user_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode); |
|
|
|
user_cnode->set_input(index, depend_cnode); |
|
|
|
return depend_cnode; |
|
|
|
} |
|
|
|
|
|
|
|
std::tuple<AnfNodePtr, AnfNodePtr, int> AtomicCleanInsertter::FindPatronNode(const KernelGraphPtr &main_graph) { |
|
|
|
auto mng = main_graph->manager(); |
|
|
|
if (mng == nullptr) { |
|
|
|
mng = Manage(main_graph, true); |
|
|
|
main_graph->set_manager(mng); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr patron_node; |
|
|
|
|
|
|
|
auto return_cnode = main_graph->get_return()->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(return_cnode); |
|
|
|
auto output_node = return_cnode->input(kFirstDataInputIndex); |
|
|
|
if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) { |
|
|
|
auto output_cnode = output_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(output_cnode); |
|
|
|
patron_node = output_cnode->input(kFirstDataInputIndex); |
|
|
|
} else { |
|
|
|
patron_node = output_node; |
|
|
|
} |
|
|
|
|
|
|
|
auto &user_nodes = mng->node_users()[patron_node]; |
|
|
|
auto user = user_nodes.begin(); |
|
|
|
return std::make_tuple(patron_node, user->first, user->second); |
|
|
|
} |
|
|
|
|
|
|
|
void AtomicCleanInsertter::PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, |
|
|
|
int index) { |
|
|
|
auto patron_user_cnode = patron_user->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(patron_user_cnode); |
|
|
|
patron_user_cnode->set_input(index, patron_node); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) { |
|
|
|
@@ -380,14 +406,14 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP |
|
|
|
kernel::Processor::CUDA); |
|
|
|
auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean"); |
|
|
|
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr)); |
|
|
|
// mng->AddFuncGraph(new_sub_graph); |
|
|
|
|
|
|
|
return broadcast_to_composite_node; |
|
|
|
} |
|
|
|
|
|
|
|
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, |
|
|
|
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) { |
|
|
|
// 1. find users, change getitem index if needed. |
|
|
|
std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUsers(const KernelGraphPtr &main_graph, |
|
|
|
const AnfNodePtr &composite_node, |
|
|
|
const FuncGraphManagerPtr &mng, |
|
|
|
bool correct_index) { |
|
|
|
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes; |
|
|
|
if (real_output_num_ <= 1) { |
|
|
|
auto users = mng->node_users()[composite_node]; |
|
|
|
@@ -409,7 +435,7 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra |
|
|
|
auto item_idx = GetValue<int64_t>(value_node->value()); |
|
|
|
if (item_idx == static_cast<int64_t>(reduce_real_output_index_)) { |
|
|
|
getitem_user_nodes.push_back(node_index); |
|
|
|
} else { |
|
|
|
} else if (correct_index) { |
|
|
|
if (real_output_num_ > 2) { |
|
|
|
// Recorrect other getitem index. |
|
|
|
int64_t new_item_idx = CalNewIndex(item_idx, reduce_real_output_index_); |
|
|
|
@@ -431,7 +457,6 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &pair : getitem_user_nodes) { |
|
|
|
// dirctory to find real user. |
|
|
|
auto real_users = mng->node_users()[pair.first]; |
|
|
|
@@ -439,12 +464,16 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return reduce_user_nodes; |
|
|
|
} |
|
|
|
|
|
|
|
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, |
|
|
|
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) { |
|
|
|
// 1. find users, change getitem index if needed. |
|
|
|
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes = |
|
|
|
FindOriginCNodeUsers(main_graph, composite_node, mng, true); |
|
|
|
for (const auto &[user_node, index] : reduce_user_nodes) { |
|
|
|
// 2. set ac output as user's input. |
|
|
|
auto user_cnode = user_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode); |
|
|
|
user_cnode->set_input(index, broadcast_to_node); |
|
|
|
// mng->SetEdge(user_node, index, broadcast_to_node); |
|
|
|
// 3. Make sure modified composite node running first. |
|
|
|
// * To not change the origin node's dependency relation, add ControlDepend and Depend node. |
|
|
|
// * For Return node and output node, ControlDepend node will change the order of these two node, which will may |
|
|
|
@@ -452,7 +481,10 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra |
|
|
|
if (IsPrimitiveCNode(user_node, prim::kPrimReturn) || user_node == main_graph->output()) { |
|
|
|
AddDepend(main_graph, broadcast_to_node, composite_node, user_node, index); |
|
|
|
} else { |
|
|
|
AddControlDepend(main_graph, composite_node, user_node, mng); |
|
|
|
auto user_cnode = user_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode); |
|
|
|
user_cnode->set_input(index, broadcast_to_node); |
|
|
|
to_process_order_.emplace_back(composite_node, user_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -473,6 +505,26 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c |
|
|
|
|
|
|
|
// Replace origin ReduceSum's user with atomic clean output, and add control depend from composite op to user. |
|
|
|
ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, mng); |
|
|
|
MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope() |
|
|
|
<< ", clean node: " << broadcast_to_node->fullname_with_scope(); |
|
|
|
} |
|
|
|
|
|
|
|
bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node, |
|
|
|
const FuncGraphManagerPtr &mng) { |
|
|
|
auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false); |
|
|
|
// If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce |
|
|
|
// node and user node. If reduce is Depend or ControlDepend node, the origin node may be wrong! |
|
|
|
return std::all_of(reduce_users.cbegin(), reduce_users.cend(), |
|
|
|
[&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool { |
|
|
|
auto &user = user_info.first; |
|
|
|
if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend) || |
|
|
|
IsPrimitiveCNode(user, prim::kPrimControlDepend)) && |
|
|
|
!(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) { |
|
|
|
return false; |
|
|
|
} else { |
|
|
|
return true; |
|
|
|
} |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { |
|
|
|
@@ -487,7 +539,8 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { |
|
|
|
bool changed = false; |
|
|
|
auto topo_nodes = TopoSort(kernel_graph->get_return()); |
|
|
|
for (const auto &node : topo_nodes) { |
|
|
|
if (!AnfAlgo::IsGraphKernel(node) || !CanActivateAtomicAdd(node)) { |
|
|
|
if (!AnfAlgo::IsGraphKernel(node) || !CanActivateAtomicAdd(node) || |
|
|
|
!IsExistStructuralObstacle(kernel_graph, node, mng)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
InsertAtomicClean(kernel_graph, node, mng); |
|
|
|
@@ -495,6 +548,14 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { |
|
|
|
} |
|
|
|
|
|
|
|
if (changed) { |
|
|
|
if (!to_process_order_.empty()) { |
|
|
|
auto [patron_node, patron_user, user_index] = FindPatronNode(kernel_graph); |
|
|
|
for (const auto &[prior, behind] : to_process_order_) { |
|
|
|
patron_node = AddControlDepend(kernel_graph, prior, behind, patron_node); |
|
|
|
} |
|
|
|
PostprocessForLastPatron(patron_node, patron_user, user_index); |
|
|
|
} |
|
|
|
|
|
|
|
mng->RemoveRoots(); |
|
|
|
mng->KeepRoots({func_graph}); |
|
|
|
} |
|
|
|
|