Browse Source

Fixes core dump issue of ScatterAdd and ScatterSub

tags/v1.6.0
huangbo77 4 years ago
parent
commit
b85f063f23
2 changed files with 8 additions and 0 deletions
  1. +7
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_arithmetic_cpu_kernel.cc
  2. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_arithmetic_cpu_kernel.h

+ 7
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_arithmetic_cpu_kernel.cc View File

@@ -52,6 +52,7 @@ void ScatterArithmeticCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
<< "', the dimension of 'input_x' should be greater than or equal to 1, but got "
<< input_shape.size() << ".";
}
input_shape_0 = SizeToInt(input_shape[0]);
input_size_ = 1;
inner_size_ = 1;
if (input_shape.empty()) {
@@ -92,6 +93,9 @@ bool ScatterArithmeticCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr>
template <typename T>
void ScatterArithmeticCPUKernel<T>::ScatterAdd(T *input, const int *indices, const T *updates) const {
for (size_t i = 0; i < indices_size_; i++) {
if (indices[i] >= input_shape_0) {
continue;
}
auto base_index_updates = i * inner_size_;
auto base_index_input = indices[i] * inner_size_;
for (size_t j = 0; j < inner_size_; j++) {
@@ -103,6 +107,9 @@ void ScatterArithmeticCPUKernel<T>::ScatterAdd(T *input, const int *indices, con
template <typename T>
void ScatterArithmeticCPUKernel<T>::ScatterSub(T *input, const int *indices, const T *updates) const {
for (size_t i = 0; i < indices_size_; i++) {
if (indices[i] >= input_shape_0) {
continue;
}
auto base_index_updates = i * inner_size_;
auto base_index_input = indices[i] * inner_size_;
for (size_t j = 0; j < inner_size_; j++) {


+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_arithmetic_cpu_kernel.h View File

@@ -49,6 +49,7 @@ class ScatterArithmeticCPUKernel : public CPUKernel {
using TypeComputeFunc = std::function<void(ScatterArithmeticCPUKernel *, T *, const int *, const T *)>;

TypeComputeFunc compute_func_;
int input_shape_0{0};
size_t input_size_{0};
size_t inner_size_{0};
size_t indices_size_{0};


Loading…
Cancel
Save