|
|
|
@@ -180,23 +180,11 @@ bool AtomicAddCheckerGPU::SuitableForAtomicAdd(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
|
|
|
|
bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) { |
|
|
|
auto input = node->cast<CNodePtr>()->input(kFirstDataInputIndex); |
|
|
|
auto src_shape_vec = GetShape(input); |
|
|
|
std::set<int64_t> axis_set = GetUniqReduceAxes(node); |
|
|
|
auto dst_shape_vec = AnfAlgo::GetOutputDeviceShape(node, 0); |
|
|
|
|
|
|
|
// case 1: all reduce |
|
|
|
if (src_shape_vec.size() == axis_set.size()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
// case 2: non-reduce axes with dimension 1 |
|
|
|
for (size_t i = 0; i < src_shape_vec.size(); ++i) { |
|
|
|
if (axis_set.find(i) == axis_set.end()) { |
|
|
|
if (src_shape_vec[i] != 1) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
// all reduce |
|
|
|
// non-reduce axes with dimension 1 |
|
|
|
return std::all_of(dst_shape_vec.cbegin(), dst_shape_vec.cend(), [](const size_t &dim) { return dim == 1; }); |
|
|
|
} |
|
|
|
|
|
|
|
void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) { |
|
|
|
|