Browse Source

!9809 Add int8 and uint8 type support to GPU ScatterAdd

From: @TFbunny
Reviewed-by: @tom__chen,@robingrosman
Signed-off-by: @robingrosman
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
cdeacdf5c7
3 changed files with 51 additions and 0 deletions
  1. +14
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc
  2. +5
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu
  3. +32
    -0
      tests/st/ops/gpu/test_scatter_add_op.py

+ 14
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc View File

@@ -39,5 +39,19 @@ MS_REG_GPU_KERNEL_ONE(ScatterAdd,
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterAddKernel, int)
// Register the GPU ScatterAdd kernel for int8 data: input and updates are
// int8, indices are int32, output is int8. Backed by ScatterAddKernel<int8_t>.
MS_REG_GPU_KERNEL_ONE(ScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
ScatterAddKernel, int8_t)
// Register the GPU ScatterAdd kernel for uint8 data: input and updates are
// uint8, indices are int32, output is uint8. Backed by ScatterAddKernel<uint8_t>.
MS_REG_GPU_KERNEL_ONE(ScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterAddKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

+ 5
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu View File

@@ -42,3 +42,8 @@ template void CalScatterAdd<half>(const size_t &inner_size, const size_t &indice
const half *updates, half *input, cudaStream_t cuda_stream);
// Explicit instantiations of CalScatterAdd for each element type registered by
// the GPU kernel, so this .cu translation unit emits the device code callers
// link against. Indices are always int32.
template void CalScatterAdd<int>(const size_t &inner_size, const size_t &indices_size, const int *indices,
const int *updates, int *input, cudaStream_t cuda_stream);
// unsigned char instantiation backs the uint8 registration — presumably
// uint8_t aliases unsigned char on all supported platforms; verify at link time.
template void CalScatterAdd<unsigned char>(const size_t &inner_size, const size_t &indices_size, const int *indices,
const unsigned char *updates, unsigned char *input,
cudaStream_t cuda_stream);
template void CalScatterAdd<int8_t>(const size_t &inner_size, const size_t &indices_size, const int *indices,
const int8_t *updates, int8_t *input, cudaStream_t cuda_stream);

+ 32
- 0
tests/st/ops/gpu/test_scatter_add_op.py View File

@@ -269,6 +269,38 @@ def test_scatter_add_disordered_dynamic_int32():
[492., 496., 500., 504.]]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_disordered_dynamic_int8():
    """ScatterAdd GPU test: int8 values with int32 indices given out of order."""
    # Values 34..45 arranged 3x4, then reversed along every axis by np.flip.
    input_x = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8)))
    scatter_indices = Tensor(np.array([[[0, 1, 2],
                                        [2, 1, 0]],
                                       [[0, 0, 0],
                                        [2, 2, 2]]]).astype(np.int32))
    update_values = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int8))
    result = scatter_add_d_net(input_x, scatter_indices, update_values)
    # NOTE(review): these literals exceed the int8 range; astype(np.int8) wraps
    # them modulo 256 — assumes the kernel's int8 accumulation wraps identically.
    expected = np.array([[464., 468., 472., 476.],
                         [187., 188., 189., 190.],
                         [492., 496., 500., 504.]]).astype(np.int8)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_disordered_dynamic_uint8():
    """ScatterAdd GPU test: uint8 values with int32 indices given out of order."""
    # Values 34..45 arranged 3x4, then reversed along every axis by np.flip.
    input_x = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8)))
    scatter_indices = Tensor(np.array([[[0, 1, 2],
                                        [2, 1, 0]],
                                       [[0, 0, 0],
                                        [2, 2, 2]]]).astype(np.int32))
    update_values = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.uint8))
    result = scatter_add_d_net(input_x, scatter_indices, update_values)
    # NOTE(review): 464 etc. exceed the uint8 range; astype(np.uint8) wraps them
    # modulo 256 — assumes the kernel's uint8 accumulation wraps identically.
    expected = np.array([[464., 468., 472., 476.],
                         [187., 188., 189., 190.],
                         [492., 496., 500., 504.]]).astype(np.uint8)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard


Loading…
Cancel
Save