Browse Source

!2661 add dtype support for scatter_add vm

Merge pull request !2661 from zhaozhenlong/op/scatter-add
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
6fb5538117
2 changed files with 29 additions and 2 deletions
  1. +2
    -0
      mindspore/ops/_op_impl/tbe/scatter_add.py
  2. +27
    -2
      tests/ut/python/ops/test_ops.py

+ 2
- 0
mindspore/ops/_op_impl/tbe/scatter_add.py View File

@@ -31,6 +31,8 @@ scatter_add_op_info = TBERegOp("ScatterAdd") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info()






+ 27
- 2
tests/ut/python/ops/test_ops.py View File

@@ -220,10 +220,10 @@ class ScatterMax(nn.Cell):
class ScatterAdd(nn.Cell):
    """ScatterAdd net definition"""


def __init__(self, ref_shape):
def __init__(self, ref_shape, dtype=np.float32):
        super(ScatterAdd, self).__init__()
        self.scatter_add = P.ScatterAdd()
self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref")
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")


    def construct(self, indices, updates):
        out = self.scatter_add(self.ref, indices, updates)
@@ -1677,12 +1677,37 @@ test_case_other_ops = [
        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
                        Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
        'skip': ['backward']}),
('ScatterAddScalar', {
'block': ScatterAdd((6,)),
'desc_inputs': (Tensor(np.array([2], np.int32)),
Tensor(np.array([2.0], np.float32))),
'skip': ['backward']}),
    ('ScatterAdd2d', {
        'block': ScatterAdd((3, 4)),
        'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
                        Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
                                         [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
        'skip': ['backward']}),
('ScatterAddF16', {
'block': ScatterAdd((6,), np.float16),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float16))),
'skip': ['backward']}),
('ScatterAddI8', {
'block': ScatterAdd((6,), np.int8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int8))),
'skip': ['backward']}),
('ScatterAddI32', {
'block': ScatterAdd((6,), np.int32),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int32))),
'skip': ['backward']}),
('ScatterAddU8', {
'block': ScatterAdd((6,), np.uint8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.uint8))),
'skip': ['backward']}),
    ('SmoothL1Loss', {
        'block': P.SmoothL1Loss(),
        'desc_inputs': [[256, 4], [256, 4]],


Loading…
Cancel
Save