Merge pull request !27203 from zhuzhongrui/gmrestags/v1.6.0
| @@ -141,7 +141,11 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp | |||
| template <typename T> | |||
| void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| auto x = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto x = reinterpret_cast<T *>(outputs[0]->addr); | |||
| auto ret = memcpy_s(x, outputs[0]->size, inputs[0]->addr, inputs[0]->size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret; | |||
| } | |||
| ComputeParams<T> params; | |||
| params.x_ = x; | |||
| params.indices_ = reinterpret_cast<int *>(inputs[1]->addr); | |||
| @@ -165,11 +169,6 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input | |||
| start += once_compute_size; | |||
| } | |||
| (void)common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -56,12 +56,12 @@ class ScatterNdFunctorKernel : public GpuKernel { | |||
| cudaMemcpyAsync(indices_stride, &out_strides_[0], indices_len, cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync failed in ScatterNdFunctorGpuFwdKernel::Launch."); | |||
| CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices, | |||
| updates, input, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync output failed"); | |||
| CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices, | |||
| updates, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -36,6 +36,7 @@ def test_op1(dtype): | |||
| Description: test cases for updating float values | |||
| Expectation: the result match scipy | |||
| """ | |||
| class ScatterNdUpdate(nn.Cell): | |||
| def __init__(self): | |||
| super(ScatterNdUpdate, self).__init__() | |||
| @@ -50,10 +51,10 @@ def test_op1(dtype): | |||
| update = Tensor(np.array([1.0, 2.2], dtype=dtype)) | |||
| scatter_nd_update = ScatterNdUpdate() | |||
| scatter_nd_update(indices, update) | |||
| print("x:\n", scatter_nd_update.x.data.asnumpy()) | |||
| output = scatter_nd_update(indices, update) | |||
| print("x:\n", output.asnumpy()) | |||
| expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]] | |||
| assert np.allclose(scatter_nd_update.x.data.asnumpy(), | |||
| assert np.allclose(output.asnumpy(), | |||
| np.array(expect, dtype=dtype)) | |||
| @@ -67,6 +68,7 @@ def test_op2(dtype): | |||
| Description: test cases for updating int values | |||
| Expectation: the result match scipy | |||
| """ | |||
| class ScatterNdUpdate(nn.Cell): | |||
| def __init__(self): | |||
| super(ScatterNdUpdate, self).__init__() | |||
| @@ -81,10 +83,10 @@ def test_op2(dtype): | |||
| update = Tensor(np.array([9, 10, 11, 12], dtype=dtype)) | |||
| scatter_nd_update = ScatterNdUpdate() | |||
| scatter_nd_update(indices, update) | |||
| print("x:\n", scatter_nd_update.x.data.asnumpy()) | |||
| output = scatter_nd_update(indices, update) | |||
| print("x:\n", output.asnumpy()) | |||
| expect = [1, 11, 3, 10, 9, 6, 7, 12] | |||
| assert np.allclose(scatter_nd_update.x.data.asnumpy(), | |||
| assert np.allclose(output.asnumpy(), | |||
| np.array(expect, dtype=dtype)) | |||
| @@ -98,6 +100,7 @@ def test_op3(dtype): | |||
| Description: test cases for updating int values | |||
| Expectation: the result match scipy | |||
| """ | |||
| class ScatterNdUpdate(nn.Cell): | |||
| def __init__(self): | |||
| super(ScatterNdUpdate, self).__init__() | |||
| @@ -114,13 +117,13 @@ def test_op3(dtype): | |||
| [7, 7, 7, 7], [8, 8, 8, 8]]], dtype=dtype)) | |||
| scatter_nd_update = ScatterNdUpdate() | |||
| scatter_nd_update(indices, update) | |||
| print("x:\n", scatter_nd_update.x.data.asnumpy()) | |||
| output = scatter_nd_update(indices, update) | |||
| print("x:\n", output.asnumpy()) | |||
| expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], | |||
| [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], | |||
| [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], | |||
| [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] | |||
| assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=dtype)) | |||
| assert np.allclose(output.asnumpy(), np.array(expect, dtype=dtype)) | |||
| @pytest.mark.level0 | |||
| @@ -133,6 +136,7 @@ def test_op4(dtype): | |||
| Description: test cases for updating single float value | |||
| Expectation: the result match scipy | |||
| """ | |||
| class ScatterNdUpdate(nn.Cell): | |||
| def __init__(self): | |||
| super(ScatterNdUpdate, self).__init__() | |||
| @@ -148,6 +152,5 @@ def test_op4(dtype): | |||
| scatter_nd_update = ScatterNdUpdate() | |||
| out = scatter_nd_update(indices, update) | |||
| print("x:\n", out) | |||
| assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy()) | |||
| expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]] | |||
| assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype)) | |||
| @@ -92,21 +92,21 @@ def test_scatter_nd_func_input_updated(): | |||
| # update | |||
| net = TestScatterNdFuncNet("update", lock, inputx, indices, updates) | |||
| net() | |||
| output = net() | |||
| expected = np.array([[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]) | |||
| np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| # add | |||
| net = TestScatterNdFuncNet("add", lock, inputx, indices, updates) | |||
| net() | |||
| output = net() | |||
| expected = np.array([[0.9, 0.3, 3.6], [0.4, 2.7, -3.2]]) | |||
| np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| # sub | |||
| net = TestScatterNdFuncNet("sub", lock, inputx, indices, updates) | |||
| net() | |||
| output = net() | |||
| expected = np.array([[-1.1, 0.3, 3.6], [0.4, -1.7, -3.2]]) | |||
| np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||