From: @TFbunny Reviewed-by: @robingrosman Signed-off-by: @robingrosmantags/v1.2.0-rc1
| @@ -49,6 +49,38 @@ MS_REG_GPU_KERNEL_TWO( | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), | |||
| GatherV2GpuFwdKernel, half, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| GatherV2GpuFwdKernel, int, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), | |||
| GatherV2GpuFwdKernel, int, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), | |||
| GatherV2GpuFwdKernel, int16_t, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), | |||
| GatherV2GpuFwdKernel, int16_t, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), | |||
| GatherV2GpuFwdKernel, int8_t, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), | |||
| GatherV2GpuFwdKernel, int8_t, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), | |||
| GatherV2GpuFwdKernel, uint8_t, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), | |||
| GatherV2GpuFwdKernel, uint8_t, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(Gather, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -59,3 +59,21 @@ template void GatherV2<double, int>(double *input, int *indices, double *output, | |||
| size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<double, int64_t>(double *input, int64_t *indices, double *output, size_t output_dim0, | |||
| size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<int, int>(int *input, int *indices, int *output, size_t output_dim0, size_t output_dim1, | |||
| size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<int, int64_t>(int *input, int64_t *indices, int *output, size_t output_dim0, size_t output_dim1, | |||
| size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<int16_t, int>(int16_t *input, int *indices, int16_t *output, size_t output_dim0, | |||
| size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<int16_t, int64_t>(int16_t *input, int64_t *indices, int16_t *output, size_t output_dim0, | |||
| size_t output_dim1, size_t output_dim2, size_t input_dim1, | |||
| cudaStream_t stream); | |||
| template void GatherV2<int8_t, int>(int8_t *input, int *indices, int8_t *output, size_t output_dim0, size_t output_dim1, | |||
| size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<int8_t, int64_t>(int8_t *input, int64_t *indices, int8_t *output, size_t output_dim0, | |||
| size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<uint8_t, int>(uint8_t *input, int *indices, uint8_t *output, size_t output_dim0, | |||
| size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<uint8_t, int64_t>(uint8_t *input, int64_t *indices, uint8_t *output, size_t output_dim0, | |||
| size_t output_dim1, size_t output_dim2, size_t input_dim1, | |||
| cudaStream_t stream); | |||
| @@ -1300,3 +1300,15 @@ def get_bprop_lin_space(self): | |||
| return zeros_like(start), zeros_like(stop), zeros_like(num) | |||
| return bprop | |||
| @bprop_getters.register(P.IndexAdd) | |||
| def get_bprop_index_add(self): | |||
| """Generate bprop for IndexAdd""" | |||
| gather = P.Gather() | |||
| _axis = self.axis | |||
| def bprop(input_x, indices, input_y, out, dout): | |||
| return dout, zeros_like(indices), gather(dout, indices, _axis) | |||
| return bprop | |||
| @@ -1178,3 +1178,183 @@ def test_gather1_float64(): | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| assert np.all(-diff < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gather1_int32(): | |||
| x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int32).reshape(2, 3, 4, 5)) | |||
| indices = Tensor(np.array([1, 3, 4], dtype='i4')) | |||
| expect = np.array([[[[1., 3., 4.], | |||
| [6., 8., 9.], | |||
| [11., 13., 14.], | |||
| [16., 18., 19.]], | |||
| [[21., 23., 24.], | |||
| [26., 28., 29.], | |||
| [31., 33., 34.], | |||
| [36., 38., 39.]], | |||
| [[41., 43., 44.], | |||
| [46., 48., 49.], | |||
| [51., 53., 54.], | |||
| [56., 58., 59.]]], | |||
| [[[61., 63., 64.], | |||
| [66., 68., 69.], | |||
| [71., 73., 74.], | |||
| [76., 78., 79.]], | |||
| [[81., 83., 84.], | |||
| [86., 88., 89.], | |||
| [91., 93., 94.], | |||
| [96., 98., 99.]], | |||
| [[101., 103., 104.], | |||
| [106., 108., 109.], | |||
| [111., 113., 114.], | |||
| [116., 118., 119.]]]]).astype(np.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| gather = GatherNet1() | |||
| output = gather(x, indices) | |||
| error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| assert np.all(-diff < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gather1_int16(): | |||
| x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int16).reshape(2, 3, 4, 5)) | |||
| indices = Tensor(np.array([1, 3, 4], dtype='i4')) | |||
| expect = np.array([[[[1., 3., 4.], | |||
| [6., 8., 9.], | |||
| [11., 13., 14.], | |||
| [16., 18., 19.]], | |||
| [[21., 23., 24.], | |||
| [26., 28., 29.], | |||
| [31., 33., 34.], | |||
| [36., 38., 39.]], | |||
| [[41., 43., 44.], | |||
| [46., 48., 49.], | |||
| [51., 53., 54.], | |||
| [56., 58., 59.]]], | |||
| [[[61., 63., 64.], | |||
| [66., 68., 69.], | |||
| [71., 73., 74.], | |||
| [76., 78., 79.]], | |||
| [[81., 83., 84.], | |||
| [86., 88., 89.], | |||
| [91., 93., 94.], | |||
| [96., 98., 99.]], | |||
| [[101., 103., 104.], | |||
| [106., 108., 109.], | |||
| [111., 113., 114.], | |||
| [116., 118., 119.]]]]).astype(np.int16) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| gather = GatherNet1() | |||
| output = gather(x, indices) | |||
| error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| assert np.all(-diff < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gather1_int8(): | |||
| x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int8).reshape(2, 3, 4, 5)) | |||
| indices = Tensor(np.array([1, 3, 4], dtype='i4')) | |||
| expect = np.array([[[[1., 3., 4.], | |||
| [6., 8., 9.], | |||
| [11., 13., 14.], | |||
| [16., 18., 19.]], | |||
| [[21., 23., 24.], | |||
| [26., 28., 29.], | |||
| [31., 33., 34.], | |||
| [36., 38., 39.]], | |||
| [[41., 43., 44.], | |||
| [46., 48., 49.], | |||
| [51., 53., 54.], | |||
| [56., 58., 59.]]], | |||
| [[[61., 63., 64.], | |||
| [66., 68., 69.], | |||
| [71., 73., 74.], | |||
| [76., 78., 79.]], | |||
| [[81., 83., 84.], | |||
| [86., 88., 89.], | |||
| [91., 93., 94.], | |||
| [96., 98., 99.]], | |||
| [[101., 103., 104.], | |||
| [106., 108., 109.], | |||
| [111., 113., 114.], | |||
| [116., 118., 119.]]]]).astype(np.int8) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| gather = GatherNet1() | |||
| output = gather(x, indices) | |||
| error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| assert np.all(-diff < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gather1_uint8(): | |||
| x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.uint8).reshape(2, 3, 4, 5)) | |||
| indices = Tensor(np.array([1, 3, 4], dtype='i4')) | |||
| expect = np.array([[[[1., 3., 4.], | |||
| [6., 8., 9.], | |||
| [11., 13., 14.], | |||
| [16., 18., 19.]], | |||
| [[21., 23., 24.], | |||
| [26., 28., 29.], | |||
| [31., 33., 34.], | |||
| [36., 38., 39.]], | |||
| [[41., 43., 44.], | |||
| [46., 48., 49.], | |||
| [51., 53., 54.], | |||
| [56., 58., 59.]]], | |||
| [[[61., 63., 64.], | |||
| [66., 68., 69.], | |||
| [71., 73., 74.], | |||
| [76., 78., 79.]], | |||
| [[81., 83., 84.], | |||
| [86., 88., 89.], | |||
| [91., 93., 94.], | |||
| [96., 98., 99.]], | |||
| [[101., 103., 104.], | |||
| [106., 108., 109.], | |||
| [111., 113., 114.], | |||
| [116., 118., 119.]]]]).astype(np.uint8) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| gather = GatherNet1() | |||
| output = gather(x, indices) | |||
| error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| assert np.all(-diff < error) | |||
| @@ -16,10 +16,12 @@ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| class NetIndexAdd(nn.Cell): | |||
| @@ -257,3 +259,110 @@ def test_index_add_invalid_inputs(): | |||
| net = NetIndexAdd(1) | |||
| _ = net(Tensor(x), Tensor(idx), Tensor(y)) | |||
| assert "out of range" in str(info.value) | |||
| class IndexAddGradNet(nn.Cell): | |||
| def __init__(self, network): | |||
| super(IndexAddGradNet, self).__init__() | |||
| self.grad = C.GradOperation(get_all=True, sens_param=True) | |||
| self.network = network | |||
| def construct(self, x, idx, y, dout): | |||
| out = self.grad(self.network)(x, idx, y, dout) | |||
| return out | |||
| def index_add_grad_with_type(nptype): | |||
| net = NetIndexAdd(1) | |||
| grad_net = IndexAddGradNet(net) | |||
| x = Tensor(np.arange(15).reshape(5, 3).astype(nptype)) | |||
| y = Tensor(np.arange(5).reshape(5, 1).astype(nptype)) | |||
| dout = Tensor(np.array([[63., 64., 65.], | |||
| [66., 67., 68.], | |||
| [69., 70., 71.], | |||
| [72., 73., 74.], | |||
| [75., 76., 77.]]).astype(nptype)) | |||
| index = Tensor(np.array([1]), dtype=mindspore.int32) | |||
| xgrad, _, ygrad = grad_net(x, index, y, dout) | |||
| expect_xgrad = np.array([[63., 64., 65.], | |||
| [66., 67., 68.], | |||
| [69., 70., 71.], | |||
| [72., 73., 74.], | |||
| [75., 76., 77.]]).astype(nptype) | |||
| expect_ygrad = np.array([[64.], | |||
| [67.], | |||
| [70.], | |||
| [73.], | |||
| [76.]]).astype(nptype) | |||
| np.testing.assert_array_equal(xgrad.asnumpy(), expect_xgrad) | |||
| np.testing.assert_array_equal(ygrad.asnumpy(), expect_ygrad) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_index_add_grad_float64(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.float64) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.float64) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_index_add_grad_float32(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.float32) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_index_add_grad_float16(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.float16) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_index_add_grad_int32(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.int32) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.int32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_index_add_grad_int16(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.int16) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.int16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_index_add_grad_int8(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.int8) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.int8) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_index_add_grad_uint8(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.uint8) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| index_add_grad_with_type(np.uint8) | |||