From d6944a70ca7f9e8e0e925e2980fe84d550c986b5 Mon Sep 17 00:00:00 2001 From: huanghui Date: Tue, 15 Sep 2020 14:12:24 +0800 Subject: [PATCH] fix cpu kernel:ScatterNdUpdate doesn't set output --- .../cpu/scatter_nd_update_cpu_kernel.cc | 13 +++++++++---- .../cpu/scatter_nd_update_cpu_kernel.h | 2 +- .../ascend/ir_fusion/add_input_to_output.cc | 1 + .../backend/optimizer/pass/add_atomic_clean.cc | 1 + .../pass/const_to_attr_strided_slice_grad.cc | 1 + 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc index b5899e008b..bd0ab15541 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc @@ -63,11 +63,11 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { bool ScatterNdUpdateCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector & /*outputs*/) { + const std::vector &outputs) { if (dtype_ == kNumberTypeFloat16) { - LaunchKernel(inputs); + LaunchKernel(inputs, outputs); } else if (dtype_ == kNumberTypeFloat32) { - LaunchKernel(inputs); + LaunchKernel(inputs, outputs); } else { MS_LOG(ERROR) << "Only support float16, float32"; return false; @@ -76,7 +76,8 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector &inp } template -void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector &inputs) { +void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { auto x = reinterpret_cast(inputs[0]->addr); auto indices = reinterpret_cast(inputs[1]->addr); auto updates = reinterpret_cast(inputs[2]->addr); @@ -100,6 +101,10 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector &input MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; } } + auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, mem_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + } } void ScatterNdUpdateCPUKernel::Check(const CNodePtr &kernel_node) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h index 4a64061c2a..f2a45cecd8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h @@ -35,7 +35,7 @@ class ScatterNdUpdateCPUKernel : public CPUKernel { const std::vector &outputs) override; template - void LaunchKernel(const std::vector &inputs); + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); private: void Check(const CNodePtr &kernel_node); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc index 52b2deaa4e..cc58d2b057 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc @@ -15,6 +15,7 @@ */ #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" #include +#include #include "backend/optimizer/ascend/ir_fusion/input_to_output_registry.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/oplib/oplib.h" diff --git a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc index 3818e8c171..0c8f04a871 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc @@ -17,6 +17,7 @@ #include "backend/optimizer/pass/add_atomic_clean.h" #include #include +#include #include "base/core_ops.h" #include "utils/utils.h" #include "utils/log_adapter.h" diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc index 82d1d981ad..b1b25c87ff 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc @@ -16,6 +16,7 @@ #include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h" #include #include +#include #include "backend/session/anf_runtime_algorithm.h" #include "ir/primitive.h" #include "utils/ms_context.h"