|
|
|
@@ -52,6 +52,7 @@ void ScatterArithmeticCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
<< "', the dimension of 'input_x' should be greater than or equal to 1, but got " |
|
|
|
<< input_shape.size() << "."; |
|
|
|
} |
|
|
|
input_shape_0 = SizeToInt(input_shape[0]); |
|
|
|
input_size_ = 1; |
|
|
|
inner_size_ = 1; |
|
|
|
if (input_shape.empty()) { |
|
|
|
@@ -92,6 +93,9 @@ bool ScatterArithmeticCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> |
|
|
|
template <typename T> |
|
|
|
void ScatterArithmeticCPUKernel<T>::ScatterAdd(T *input, const int *indices, const T *updates) const { |
|
|
|
for (size_t i = 0; i < indices_size_; i++) { |
|
|
|
if (indices[i] >= input_shape_0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto base_index_updates = i * inner_size_; |
|
|
|
auto base_index_input = indices[i] * inner_size_; |
|
|
|
for (size_t j = 0; j < inner_size_; j++) { |
|
|
|
@@ -103,6 +107,9 @@ void ScatterArithmeticCPUKernel<T>::ScatterAdd(T *input, const int *indices, con |
|
|
|
template <typename T> |
|
|
|
void ScatterArithmeticCPUKernel<T>::ScatterSub(T *input, const int *indices, const T *updates) const { |
|
|
|
for (size_t i = 0; i < indices_size_; i++) { |
|
|
|
if (indices[i] >= input_shape_0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto base_index_updates = i * inner_size_; |
|
|
|
auto base_index_input = indices[i] * inner_size_; |
|
|
|
for (size_t j = 0; j < inner_size_; j++) { |
|
|
|
|