Browse Source

fixes bug in scatternd

tags/v1.4.0
huangbo77 4 years ago
parent
commit
e970a56eec
1 changed files with 7 additions and 2 deletions
  1. +7
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_cpu_kernel.cc

+ 7
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_cpu_kernel.cc View File

@@ -40,7 +40,12 @@ void Compute(const ComputeParams<S, T> *params, const size_t start, const size_t
}
offset += index * out_strides->at(j) * params->unit_size_;
}
target[offset] += updates[params->unit_size_ * i];
auto task = [&](size_t update_start, size_t update_end) {
for (size_t idx = update_start; idx < update_end; idx++) {
target[offset + idx] += updates[params->unit_size_ * i + idx];
}
};
CPUKernelUtils::ParallelFor(task, params->unit_size_);
}
}
} // namespace
@@ -90,7 +95,7 @@ bool ScatterNdCPUKernel<S, T>::Launch(const std::vector<kernel::AddressPtr> &inp
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto target = reinterpret_cast<T *>(outputs[0]->addr);
auto target_init = memset_s(target, outputs[0]->size / sizeof(T), static_cast<T>(0.0), outputs[0]->size / sizeof(T));
auto target_init = memset_s(target, outputs[0]->size, static_cast<T>(0.0), outputs[0]->size);
if (target_init != EOK) {
MS_LOG(EXCEPTION) << "ScatterNdCPUKernel Launch task memset failed.";
}


Loading…
Cancel
Save