| @@ -40,10 +40,10 @@ class ScatterAddKernel : public GpuKernel { | |||
| int *indices = GetDeviceAddress<int>(inputs, 1); | |||
| T *updates = GetDeviceAddress<T>(inputs, 2); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| CalScatterAdd(inner_size_, indices_size_, indices, updates, input, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync output failed"); | |||
| CalScatterAdd(inner_size_, indices_size_, indices, updates, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -40,10 +40,10 @@ class ScatterUpdateKernel : public GpuKernel { | |||
| int *indices = GetDeviceAddress<int>(inputs, 1); | |||
| T *updates = GetDeviceAddress<T>(inputs, 2); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| CalScatterUpdate(inner_size_, indices_size_, indices, updates, input, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync output failed"); | |||
| CalScatterUpdate(inner_size_, indices_size_, indices, updates, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -19,26 +19,26 @@ | |||
| template <typename T> | |||
| __global__ void ScatterAdd(const int inner_size, const int updates_size, const int *indices, const T *updates, | |||
| T *output) { | |||
| T *input) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { | |||
| const size_t index = pos / inner_size; | |||
| const size_t offset = pos % inner_size; | |||
| const size_t current_pos = indices[index] * inner_size + offset; | |||
| MsAtomicAdd(&output[current_pos], updates[pos]); | |||
| MsAtomicAdd(&input[current_pos], updates[pos]); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output, | |||
| void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input, | |||
| cudaStream_t cuda_stream) { | |||
| const int updates_size = inner_size * indices_size; | |||
| ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, indices, updates, | |||
| output); | |||
| input); | |||
| } | |||
| template void CalScatterAdd<float>(const int &inner_size, const int &indices_size, const int *indices, | |||
| const float *updates, float *output, cudaStream_t cuda_stream); | |||
| const float *updates, float *input, cudaStream_t cuda_stream); | |||
| template void CalScatterAdd<half>(const int &inner_size, const int &indices_size, const int *indices, | |||
| const half *updates, half *output, cudaStream_t cuda_stream); | |||
| const half *updates, half *input, cudaStream_t cuda_stream); | |||
| template void CalScatterAdd<int>(const int &inner_size, const int &indices_size, const int *indices, const int *updates, | |||
| int *output, cudaStream_t cuda_stream); | |||
| int *input, cudaStream_t cuda_stream); | |||
| @@ -20,7 +20,7 @@ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output, | |||
| void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ADD_IMPL_CUH_ | |||
| @@ -18,31 +18,31 @@ | |||
| template <typename T> | |||
| __global__ void ScatterUpdate(const int inner_size, const int updates_size, const int *indices, const T *updates, | |||
| T *output) { | |||
| T *input) { | |||
| for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { | |||
| const int index = pos / inner_size; | |||
| const int offset = pos % inner_size; | |||
| const int current_pos = indices[index] * inner_size + offset; | |||
| output[current_pos] = updates[pos]; | |||
| input[current_pos] = updates[pos]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output, | |||
| void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input, | |||
| cudaStream_t cuda_stream) { | |||
| const int updates_size = inner_size * indices_size; | |||
| ScatterUpdate<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, indices, updates, | |||
| output); | |||
| input); | |||
| } | |||
| template void CalScatterUpdate<float>(const int &inner_size, const int &indices_size, const int *indices, | |||
| const float *updates, float *output, cudaStream_t cuda_stream); | |||
| const float *updates, float *input, cudaStream_t cuda_stream); | |||
| template void CalScatterUpdate<half>(const int &inner_size, const int &indices_size, const int *indices, | |||
| const half *updates, half *output, cudaStream_t cuda_stream); | |||
| const half *updates, half *input, cudaStream_t cuda_stream); | |||
| template void CalScatterUpdate<int>(const int &inner_size, const int &indices_size, const int *indices, | |||
| const int *updates, int *output, cudaStream_t cuda_stream); | |||
| const int *updates, int *input, cudaStream_t cuda_stream); | |||
| template void CalScatterUpdate<unsigned char>(const int &inner_size, const int &indices_size, const int *indices, | |||
| const unsigned char *updates, unsigned char *output, | |||
| const unsigned char *updates, unsigned char *input, | |||
| cudaStream_t cuda_stream); | |||
| template void CalScatterUpdate<int8_t>(const int &inner_size, const int &indices_size, const int *indices, | |||
| const int8_t *updates, int8_t *output, cudaStream_t cuda_stream); | |||
| const int8_t *updates, int8_t *input, cudaStream_t cuda_stream); | |||
| @@ -20,7 +20,7 @@ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output, | |||
| void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_UPDATE_IMPL_CUH_ | |||
| @@ -73,13 +73,23 @@ class _ScatterOp_Dynamic(PrimitiveWithCheck): | |||
| """ | |||
| Defines Scatter operators with dynamic shape | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||
| sig.make_sig('updates', dtype=sig.sig_dtype.T) | |||
| ) | |||
| def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | |||
| if np.all(np.array(x_shape) != -1): | |||
| if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]: | |||
| raise ValueError(f"For '{prim_name}', " | |||
| f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " | |||
| f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | |||
| # x_shape cannot be dynamic | |||
| if np.any(np.array(x_shape) == -1): | |||
| raise ValueError(f"x does not support dynamic shape") | |||
| # support indices and updates dynamic | |||
| if np.any(np.array(indices_shape) == -1) or np.any(np.array(updates_shape) == -1): | |||
| pass | |||
| elif indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]: | |||
| raise ValueError(f"For '{prim_name}', " | |||
| f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " | |||
| f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False): | |||
| @@ -3176,7 +3186,7 @@ class ScatterUpdate(_ScatterOp_Dynamic): | |||
| Tensor, has the same shape and type as `input_x`. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| ``Ascend`` ``GPU`` | |||
| Examples: | |||
| >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]) | |||
| @@ -56,8 +56,9 @@ class TestScatterAddDynamicNet(nn.Cell): | |||
| self.updates = Parameter(updates, name="updates") | |||
| def construct(self): | |||
| out = self.test_dynamic(self.inputx) | |||
| out = self.scatter_add(out, self.indices, self.updates) | |||
| indices = self.test_dynamic(self.indices) | |||
| updates = self.test_dynamic(self.updates) | |||
| out = self.scatter_add(self.inputx, indices, updates) | |||
| return out | |||
| def scatter_add_d_net(inputx, indices, updates): | |||
| @@ -66,22 +67,24 @@ def scatter_add_d_net(inputx, indices, updates): | |||
| return net() | |||
| class TestScatterAddDynamicNet2(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, inputx): | |||
| super(TestScatterAddDynamicNet2, self).__init__() | |||
| self.scatter_add = P.ScatterAdd() | |||
| self.test_dynamic = inner.GpuConvertToDynamicShape() | |||
| self.inputx = Parameter(inputx, name="inputx") | |||
| def construct(self, inputx, indices, updates): | |||
| out = self.test_dynamic(inputx) | |||
| out = self.scatter_add(out, indices, updates) | |||
| def construct(self, indices, updates): | |||
| indices = self.test_dynamic(indices) | |||
| updates = self.test_dynamic(updates) | |||
| out = self.scatter_add(self.inputx, indices, updates) | |||
| return out | |||
| def scatter_add_d2_net(inputx_1, indices_1, updates_1, inputx_2, | |||
| def scatter_add_d2_net(inputx, indices_1, updates_1, | |||
| indices_2, updates_2): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = TestScatterAddDynamicNet2() | |||
| out1 = net(inputx_1, indices_1, updates_1) | |||
| out2 = net(inputx_2, indices_2, updates_2) | |||
| net = TestScatterAddDynamicNet2(inputx) | |||
| out1 = net(indices_1, updates_1) | |||
| out2 = net(indices_2, updates_2) | |||
| return (out1, out2) | |||
| @pytest.mark.level0 | |||
| @@ -96,6 +99,20 @@ def test_scatter_add_small_float32(): | |||
| [12., 14., 16.]]) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scatter_add_input_updated(): | |||
| inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) | |||
| indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) | |||
| updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) | |||
| lock = True | |||
| net = TestScatterAddNet(lock, inputx, indices, updates) | |||
| net() | |||
| expected = np.array([[6., 8., 10.], | |||
| [12., 14., 16.]]) | |||
| np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -274,39 +291,16 @@ def test_scatter_add_input_less_than_1_dynamic_float32(): | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scatter_add_dynamic_two_inputs(): | |||
| inputx_1 = Tensor(np.zeros((2, 3)).astype(np.float32)) | |||
| inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) | |||
| indices_1 = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) | |||
| updates_1 = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) | |||
| inputx_2 = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32)) | |||
| indices_2 = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32)) | |||
| updates_2 = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32)) | |||
| output_1, output_2 = scatter_add_d2_net(inputx_1, indices_1, updates_1, | |||
| inputx_2, indices_2, updates_2) | |||
| indices_2 = Tensor(np.array([[0, 0], [1, 1], [1, 0]]).astype(np.int32)) | |||
| updates_2 = Tensor(np.flip(np.arange(18).reshape((3, 2, 3)).astype(np.float32))) | |||
| output_1, output_2 = scatter_add_d2_net(inputx, indices_1, updates_1, | |||
| indices_2, updates_2) | |||
| expected_1 = np.array([[6., 8., 10.], | |||
| [12., 14., 16.]]) | |||
| expected_2 = np.array([[[[1., 2., 3., 4.], | |||
| [5., 6., 7., 8.], | |||
| [9., 10., 11., 12.]], | |||
| [[13., 14., 15., 16.], | |||
| [17., 18., 19., 20.], | |||
| [21., 22., 23., 24.]]], | |||
| [[[73., 74., 75., 76.], | |||
| [77., 78., 79., 80.], | |||
| [81., 82., 83., 84.]], | |||
| [[85., 86., 87., 88.], | |||
| [89., 90., 91., 92.], | |||
| [93., 94., 95., 96.]]], | |||
| [[[25., 26., 27., 28.], | |||
| [29., 30., 31., 32.], | |||
| [33., 34., 35., 36.]], | |||
| [[37., 38., 39., 40.], | |||
| [41., 42., 43., 44.], | |||
| [45., 46., 47., 48.]]], | |||
| [[[49., 50., 51., 52.], | |||
| [53., 54., 55., 56.], | |||
| [57., 58., 59., 60.]], | |||
| [[61., 62., 63., 64.], | |||
| [65., 66., 67., 68.], | |||
| [69., 70., 71., 72.]]]]) | |||
| expected_2 = np.array([[39., 38., 37.], | |||
| [36., 35., 34.]]) | |||
| np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1) | |||
| np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2) | |||
| @@ -50,8 +50,9 @@ class TestScatterUpdateDynamicNet(nn.Cell): | |||
| self.updates = Parameter(updates, name="updates") | |||
| def construct(self): | |||
| out = self.test_dynamic(self.inputx) | |||
| out = self.scatter_update(out, self.indices, self.updates) | |||
| indices = self.test_dynamic(self.indices) | |||
| updates = self.test_dynamic(self.updates) | |||
| out = self.scatter_update(self.inputx, indices, updates) | |||
| return out | |||
| def scatter_update_d_net(inputx, indices, updates): | |||
| @@ -60,22 +61,24 @@ def scatter_update_d_net(inputx, indices, updates): | |||
| return net() | |||
| class TestScatterUpdateDynamicNet2(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, inputx): | |||
| super(TestScatterUpdateDynamicNet2, self).__init__() | |||
| self.scatter_update = P.ScatterUpdate() | |||
| self.test_dynamic = inner.GpuConvertToDynamicShape() | |||
| self.inputx = Parameter(inputx, name="inputx") | |||
| def construct(self, inputx, indices, updates): | |||
| out = self.test_dynamic(inputx) | |||
| out = self.scatter_update(out, indices, updates) | |||
| def construct(self, indices, updates): | |||
| indices = self.test_dynamic(indices) | |||
| updates = self.test_dynamic(updates) | |||
| out = self.scatter_update(self.inputx, indices, updates) | |||
| return out | |||
| def scatter_update_d2_net(inputx_1, indices_1, updates_1, inputx_2, | |||
| def scatter_update_d2_net(inputx, indices_1, updates_1, | |||
| indices_2, updates_2): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = TestScatterUpdateDynamicNet2() | |||
| out1 = net(inputx_1, indices_1, updates_1) | |||
| out2 = net(inputx_2, indices_2, updates_2) | |||
| net = TestScatterUpdateDynamicNet2(inputx) | |||
| out1 = net(indices_1, updates_1) | |||
| out2 = net(indices_2, updates_2) | |||
| return (out1, out2) | |||
| @pytest.mark.level0 | |||
| @@ -90,6 +93,19 @@ def test_scatter_update_small_float32(): | |||
| [3., 4., 5.]]) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scatter_update_input_updated(): | |||
| inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) | |||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32)) | |||
| net = TestScatterUpdateNet(inputx, indices, updates) | |||
| net() | |||
| expected = np.array([[0., 1., 2.], | |||
| [3., 4., 5.]]) | |||
| np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -328,20 +344,16 @@ def test_scatter_update_disordered_dynamic_int32(): | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scatter_update_two_inputs(): | |||
| inputx_1 = Tensor(np.zeros((2, 3)).astype(np.float32)) | |||
| inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) | |||
| indices_1 = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| updates_1 = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32)) | |||
| inputx_2 = Tensor(np.array([[0.214141, 0.415151, 0.51516], | |||
| [0.876542, 0.451611, 0.55112], | |||
| [0.111244, 0.633333, 0.34444]]).astype(np.float32)) | |||
| indices_2 = Tensor(np.array([1, 0, 2]).astype(np.int32)) | |||
| updates_2 = Tensor(np.arange(34, 43).reshape((3, 3)).astype(np.float32)) | |||
| output_1, output_2 = scatter_update_d2_net(inputx_1, indices_1, updates_1, | |||
| inputx_2, indices_2, updates_2) | |||
| indices_2 = Tensor(np.array([1]).astype(np.int32)) | |||
| updates_2 = Tensor(np.arange(34, 37).reshape((1, 3)).astype(np.float32)) | |||
| output_1, output_2 = scatter_update_d2_net(inputx, indices_1, updates_1, | |||
| indices_2, updates_2) | |||
| expected_1 = np.array([[0., 1., 2.], | |||
| [3., 4., 5.]]) | |||
| expected_2 = np.array([[37., 38., 39.], | |||
| [34., 35., 36.], | |||
| [40., 41., 42.]], dtype=np.float32) | |||
| [3., 4., 5.]], dtype=np.float32) | |||
| expected_2 = np.array([[0., 1., 2.], | |||
| [34., 35., 36.]], dtype=np.float32) | |||
| np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1) | |||
| np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2) | |||