Browse Source

add int8 and uint8 support to scatteradd and fix doc example

tags/v1.1.0
TFbunny 5 years ago
parent
commit
fb1a65c469
4 changed files with 95 additions and 2 deletions
  1. +14
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.cc
  2. +5
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu
  3. +2
    -2
      mindspore/ops/operations/array_ops.py
  4. +74
    -0
      tests/st/ops/gpu/test_scatter_update_op.py

+ 14
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.cc View File

@@ -39,5 +39,19 @@ MS_REG_GPU_KERNEL_ONE(ScatterUpdate,
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterUpdateKernel, int)
MS_REG_GPU_KERNEL_ONE(ScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
ScatterUpdateKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(ScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterUpdateKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

+ 5
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu View File

@@ -41,3 +41,8 @@ template void CalScatterUpdate<half>(const int &inner_size, const int &indices_s
const half *updates, half *output, cudaStream_t cuda_stream);
template void CalScatterUpdate<int>(const int &inner_size, const int &indices_size, const int *indices,
const int *updates, int *output, cudaStream_t cuda_stream);
template void CalScatterUpdate<unsigned char>(const int &inner_size, const int &indices_size, const int *indices,
const unsigned char *updates, unsigned char *output,
cudaStream_t cuda_stream);
template void CalScatterUpdate<int8_t>(const int &inner_size, const int &indices_size, const int *indices,
const int8_t *updates, int8_t *output, cudaStream_t cuda_stream);

+ 2
- 2
mindspore/ops/operations/array_ops.py View File

@@ -2907,8 +2907,8 @@ class ScatterUpdate(_ScatterOp_Dynamic):
Examples:
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> np_updates = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
>>> indices = Tensor(np.array([0, 1]), mindspore.int32)
>>> np_updates = np.array([[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]])
>>> updates = Tensor(np_updates, mindspore.float32)
>>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, updates)


+ 74
- 0
tests/st/ops/gpu/test_scatter_update_op.py View File

@@ -163,3 +163,77 @@ def test_scatter_update_large_shape_float16():
[88., 89., 90., 91.],
[92., 93., 94., 95.]]]]).astype(np.float16)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_disordered_int8():
inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8)))
indices = Tensor(np.array([1, 2]).astype(np.int32))
updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.int8))
output = scatter_update_net(inputx, indices, updates)
expected = np.array([[45., 44., 43., 42.],
[63., 64., 65., 66.],
[67., 68., 69., 70.]]).astype(np.int8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_large_shape_int8():
inputx = Tensor(np.arange(96).reshape((4, 2, 3, 4)).astype(np.int8))
indices = Tensor(np.array([1, 0]).astype(np.int32))
updates = Tensor(np.flip(np.arange(48).reshape((2, 2, 3, 4)).astype(np.int8)))
output = scatter_update_net(inputx, indices, updates)
expected = np.array([[[[23., 22., 21., 20.],
[19., 18., 17., 16.],
[15., 14., 13., 12.]],
[[11., 10., 9., 8.],
[7., 6., 5., 4.],
[3., 2., 1., 0.]]],
[[[47., 46., 45., 44.],
[43., 42., 41., 40.],
[39., 38., 37., 36.]],
[[35., 34., 33., 32.],
[31., 30., 29., 28.],
[27., 26., 25., 24.]]],
[[[48., 49., 50., 51.],
[52., 53., 54., 55.],
[56., 57., 58., 59.]],
[[60., 61., 62., 63.],
[64., 65., 66., 67.],
[68., 69., 70., 71.]]],
[[[72., 73., 74., 75.],
[76., 77., 78., 79.],
[80., 81., 82., 83.]],
[[84., 85., 86., 87.],
[88., 89., 90., 91.],
[92., 93., 94., 95.]]]]).astype(np.int8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_large_uint8():
inputx = Tensor(np.zeros((4, 3)).astype(np.uint8))
indices = Tensor(np.array([[2, 1], [0, 3]]).astype(np.int32))
updates = Tensor(np.arange(63, 75).reshape((2, 2, 3)).astype(np.uint8))
output = scatter_update_net(inputx, indices, updates)
expected = np.array([[69., 70., 71.],
[66., 67., 68.],
[63., 64., 65.],
[72., 73., 74.]]).astype(np.uint8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_disordered_uint8():
inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8)))
indices = Tensor(np.array([1, 2]).astype(np.int32))
updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.uint8))
output = scatter_update_net(inputx, indices, updates)
expected = np.array([[45., 44., 43., 42.],
[63., 64., 65., 66.],
[67., 68., 69., 70.]]).astype(np.uint8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

Loading…
Cancel
Save