|
|
|
@@ -55,8 +55,9 @@ std::set<int64_t> GetUniqReduceAxes(const AnfNodePtr &node, bool is_ascend = fal |
|
|
|
axis_vec.push_back(i); |
|
|
|
} |
|
|
|
} else { |
|
|
|
std::transform(axis_vec.begin(), axis_vec.end(), axis_vec.begin(), |
|
|
|
[&src_shape_vec](int64_t axis) -> int64_t { return axis < 0 ? axis + src_shape_vec.size() : axis; }); |
|
|
|
std::transform(axis_vec.begin(), axis_vec.end(), axis_vec.begin(), [&src_shape_vec](int64_t axis) -> int64_t { |
|
|
|
return axis < 0 ? axis + SizeToLong(src_shape_vec.size()) : axis; |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
std::set<int64_t> axis_set(axis_vec.begin(), axis_vec.end()); |
|
|
|
@@ -113,7 +114,6 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) { |
|
|
|
// Rule: Only one ReduceSum inside sub-graph. |
|
|
|
auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex); |
|
|
|
if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) { |
|
|
|
AnfNodePtrList reduce_ops; |
|
|
|
size_t reduce_cnt = 0; |
|
|
|
const auto &inputs = real_return_node->cast<CNodePtr>()->inputs(); |
|
|
|
for (size_t i = 1; i < inputs.size(); ++i) { |
|
|
|
@@ -176,8 +176,8 @@ bool AtomicAddCheckerGPU::SuitableForAtomicAdd(const AnfNodePtr &node) { |
|
|
|
// It is suitable for atomic add only if the reduce num is greater than or equal to 1024. |
|
|
|
if (axis_set.count(src_shape_vec.size() - 1) != 0) { |
|
|
|
size_t reduce_size = |
|
|
|
std::accumulate(axis_set.begin(), axis_set.end(), 1, |
|
|
|
[&src_shape_vec](size_t size, int64_t axis) { return size * src_shape_vec[axis]; }); |
|
|
|
std::accumulate(axis_set.begin(), axis_set.end(), LongToSize(1), |
|
|
|
[&src_shape_vec](size_t size, int64_t axis) { return size * LongToSize(src_shape_vec[axis]); }); |
|
|
|
return reduce_size >= 1024; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -314,7 +314,7 @@ void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGra |
|
|
|
sub_graph->set_output(new_out_node); |
|
|
|
} |
|
|
|
|
|
|
|
void AtomicCleanInsertter::CorrectAbstract(const AnfNodePtr &composite_node) { |
|
|
|
void AtomicCleanInsertter::CorrectAbstract(const AnfNodePtr &composite_node) const { |
|
|
|
// If there is only one output (ReduceSum), it should be a fake output with the same abstract as the origin output. |
|
|
|
if (real_output_num_ <= 1) { |
|
|
|
return; |
|
|
|
@@ -333,8 +333,7 @@ void AtomicCleanInsertter::CorrectAbstract(const AnfNodePtr &composite_node) { |
|
|
|
composite_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_out_specs)); |
|
|
|
} |
|
|
|
|
|
|
|
void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, |
|
|
|
const FuncGraphManagerPtr &mng) { |
|
|
|
void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) { |
|
|
|
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node); |
|
|
|
auto mng_sub = sub_graph->manager(); |
|
|
|
if (mng_sub == nullptr) { |
|
|
|
@@ -367,7 +366,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, |
|
|
|
} |
|
|
|
|
|
|
|
void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, |
|
|
|
const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) { |
|
|
|
const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) const { |
|
|
|
// Create depend node to hold execution order. |
|
|
|
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node}; |
|
|
|
auto depend_cnode = main_graph->NewCNode(d_inputs); |
|
|
|
@@ -376,10 +375,11 @@ void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNo |
|
|
|
|
|
|
|
auto user_cnode = user_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode); |
|
|
|
user_cnode->set_input(index, depend_cnode); |
|
|
|
user_cnode->set_input(IntToSize(index), depend_cnode); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node) { |
|
|
|
CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph, |
|
|
|
const CNodePtr &composite_node) const { |
|
|
|
// Insert update_state_node, which needs a monad node mounted as its input. |
|
|
|
auto u = NewValueNode(kUMonad); |
|
|
|
u->set_abstract(kUMonad->ToAbstract()); |
|
|
|
@@ -444,7 +444,7 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP |
|
|
|
std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUsers(const KernelGraphPtr &main_graph, |
|
|
|
const AnfNodePtr &composite_node, |
|
|
|
const FuncGraphManagerPtr &mng, |
|
|
|
bool correct_index) { |
|
|
|
bool correct_index) const { |
|
|
|
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes; |
|
|
|
if (real_output_num_ <= 1) { |
|
|
|
auto users = mng->node_users()[composite_node]; |
|
|
|
@@ -469,7 +469,7 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs |
|
|
|
} else if (correct_index) { |
|
|
|
if (real_output_num_ > 2) { |
|
|
|
// Re-correct the index of the other getitem nodes. |
|
|
|
int64_t new_item_idx = CalNewIndex(item_idx, reduce_real_output_index_); |
|
|
|
int64_t new_item_idx = CalNewIndex(item_idx, SizeToLong(reduce_real_output_index_)); |
|
|
|
AnfNodePtrList new_inputs = {NewValueNode(prim::kPrimTupleGetItem), composite_node, |
|
|
|
NewValueNode(new_item_idx)}; |
|
|
|
auto new_out = main_graph->NewCNode(new_inputs); |
|
|
|
@@ -477,13 +477,13 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs |
|
|
|
for (const auto &[user, index] : mng->node_users()[get_item_cnode]) { |
|
|
|
auto user_cnode = user->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode); |
|
|
|
user_cnode->set_input(index, new_out); |
|
|
|
user_cnode->set_input(IntToSize(index), new_out); |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (const auto &[user, index] : mng->node_users()[node_index.first]) { |
|
|
|
auto user_cnode = user->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode); |
|
|
|
user_cnode->set_input(index, composite_node); |
|
|
|
user_cnode->set_input(IntToSize(index), composite_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -512,7 +512,7 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra |
|
|
|
main_graph->AddNode(load_node); |
|
|
|
auto user_cnode = user_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode); |
|
|
|
user_cnode->set_input(index, load_node); |
|
|
|
user_cnode->set_input(IntToSize(index), load_node); |
|
|
|
to_process_order_.emplace_back(composite_node, user_node); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -529,7 +529,7 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c |
|
|
|
|
|
|
|
// Insert extra input(broadcast node output) to composite node, and make Reducesum inplaceassign to it. |
|
|
|
// Note: if it's single output, this will increase total memory because of a fake out. |
|
|
|
ProcessOriginCNode(origin_composite_node, broadcast_to_node, mng); |
|
|
|
ProcessOriginCNode(origin_composite_node, broadcast_to_node); |
|
|
|
|
|
|
|
// Insert update_state_node to keep execution order. |
|
|
|
auto update_state_node = InsertUpdateState(main_graph, origin_composite_node); |
|
|
|
|