From b367da88eb4f4cd96d59e80a95d73466391877e1 Mon Sep 17 00:00:00 2001 From: huanghui Date: Thu, 20 Aug 2020 15:27:09 +0800 Subject: [PATCH] fix ScatterNdUpdate cpu kernel --- .../cpu/scatter_nd_update_cpu_kernel.cc | 29 +++++++------------ .../cpu/scatter_nd_update_cpu_kernel.h | 2 +- .../cpu/unique_with_pad_cpu_kernel.h | 4 +-- mindspore/core/ir/anf.cc | 15 ++++++++-- tests/st/ops/cpu/test_scatter_nd_update_op.py | 18 ++++++------ 5 files changed, 36 insertions(+), 32 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc index b8584f8c4c..71c81284a5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc @@ -25,9 +25,6 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); auto updates_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - if (indices_shape.size() < 2) { - MS_LOG(EXCEPTION) << "Indices' dimension less than 2"; - } auto indices_unit_rank = indices_shape.back(); if (indices_unit_rank > shape.size()) { MS_LOG(EXCEPTION) << "Value of last dimension of indices is greater than shape rank"; @@ -66,11 +63,11 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { bool ScatterNdUpdateCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs) { + const std::vector & /*outputs*/) { if (dtype_ == kNumberTypeFloat16) { - LaunchKernel(inputs, outputs); + LaunchKernel(inputs); } else if (dtype_ == kNumberTypeFloat32) { - LaunchKernel(inputs, outputs); + LaunchKernel(inputs); } else { MS_LOG(ERROR) << "Only support float16, float32"; return false; @@ -79,30 +76,26 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector &inp } template -void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { +void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector &inputs) { auto x = reinterpret_cast(inputs[0]->addr); auto indices = reinterpret_cast(inputs[1]->addr); auto updates = reinterpret_cast(inputs[2]->addr); - auto y = reinterpret_cast(outputs[0]->addr); for (int i = 0; i < num_units_; ++i) { int offset = 0; for (int j = 0; j < indices_unit_rank_; ++j) { - offset += indices[i * indices_unit_rank_ + j] * out_strides_[j] * unit_size_; + auto index = indices[i * indices_unit_rank_ + j]; + if (index < 0) { + MS_LOG(EXCEPTION) << "Error, Indices exist element which less than 0. element=" << index; + } + offset += index * out_strides_[j] * unit_size_; } output_unit_offsets_[i] = offset; } - auto mem_bits = outputs[0]->size; - auto ret = memcpy_s(y, mem_bits, x, mem_bits); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; - } - for (int i = 0; i < num_units_; i++) { - ret = - memcpy_s(y + output_unit_offsets_[i], unit_size_ * sizeof(T), updates + unit_size_ * i, unit_size_ * sizeof(T)); + auto ret = + memcpy_s(x + output_unit_offsets_[i], unit_size_ * sizeof(T), updates + unit_size_ * i, unit_size_ * sizeof(T)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h index b5ee09295a..80b5714bb2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h @@ -35,7 +35,7 @@ class ScatterNdUpdateCPUKernel : public CPUKernel { const std::vector &outputs) override; template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + void LaunchKernel(const std::vector &inputs); private: void Check(const CNodePtr &kernel_node); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_with_pad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_with_pad_cpu_kernel.h index d7b7563342..5d063853bf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_with_pad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_with_pad_cpu_kernel.h @@ -39,8 +39,8 @@ class UniqueWithPadCPUKernel : public CPUKernel { private: void CheckParam(const CNodePtr &kernel_node); - int64_t n_; - TypeId dtype_; + int64_t n_{0}; + TypeId dtype_{0}; }; MS_REG_CPU_KERNEL(UniqueWithPad, diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index a681644b3d..89f9eec7e4 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -224,6 +224,14 @@ std::string GetMaketupleNodeTarget(const CNodePtr &cnode) { std::string default_target = context_ptr->device_target(); return default_target; } + +std::string GetTupleGetItemTarget(const CNodePtr &cnode, const PrimitivePtr &primitive) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(primitive); + auto input_target = GetCNodeTarget(cnode->input(1)); + primitive->set_attr("primitive_target", MakeValue(input_target)); + return input_target; +} } // namespace std::string GetCNodeTarget(const AnfNodePtr &node) { @@ -256,8 +264,8 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || - IsPrimitive(attr_input, prim::kPrimTupleGetItem) || IsPrimitive(attr_input, prim::kPrimControlDepend) || - IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { + IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || + IsPrimitive(attr_input, prim::kPrimPartial)) { primitive->EraseAttr("primitive_target"); return default_target; } @@ -273,6 +281,9 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { return GetMaketupleNodeTarget(cnode); } + if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + return GetTupleGetItemTarget(cnode, primitive); + } return default_target; } } // namespace mindspore diff --git a/tests/st/ops/cpu/test_scatter_nd_update_op.py b/tests/st/ops/cpu/test_scatter_nd_update_op.py index e563c55faa..4e0fbf5408 100644 --- a/tests/st/ops/cpu/test_scatter_nd_update_op.py +++ b/tests/st/ops/cpu/test_scatter_nd_update_op.py @@ -64,10 +64,10 @@ def test_op1(): update = Tensor(np.array([1.0, 2.2]), mstype.float32) scatter_nd_update = ScatterNdUpdate1() - output = scatter_nd_update(indices, update) - print("output:\n", output) + scatter_nd_update(indices, update) + print("x:\n", scatter_nd_update.x.default_input) expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]] - assert np.allclose(output.asnumpy(), np.array(expect, np.float)) + assert np.allclose(scatter_nd_update.x.default_input.asnumpy(), np.array(expect, np.float)) @pytest.mark.level0 @@ -78,10 +78,10 @@ def test_op2(): update = Tensor(np.array([9, 10, 11, 12]), mstype.float32) scatter_nd_update = ScatterNdUpdate2() - output = scatter_nd_update(indices, update) - print("output:\n", output) + scatter_nd_update(indices, update) + print("x:\n", scatter_nd_update.x.default_input) expect = [1, 11, 3, 10, 9, 6, 7, 12] - assert np.allclose(output.asnumpy(), np.array(expect, dtype=float)) + assert np.allclose(scatter_nd_update.x.default_input.asnumpy(), np.array(expect, dtype=float)) @pytest.mark.level0 @@ -95,10 +95,10 @@ def test_op3(): [7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32) scatter_nd_update = ScatterNdUpdate3() - output = scatter_nd_update(indices, update) - print("output:\n", output) + scatter_nd_update(indices, update) + print("x:\n", scatter_nd_update.x.default_input) expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] - assert np.allclose(output.asnumpy(), np.array(expect, dtype=float)) + assert np.allclose(scatter_nd_update.x.default_input.asnumpy(), np.array(expect, dtype=float))