
change index_add op input_x to type Parameter

tags/v1.2.0-rc1
tom__chen committed 4 years ago · commit 56e8aa9f7e

2 changed files with 74 additions and 65 deletions:
  1. mindspore/ops/operations/math_ops.py (+9, -4)
  2. tests/st/ops/gpu/test_index_add_op.py (+65, -61)
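With this change, `IndexAdd` treats `input_x` as state it writes back to, so callers must hold `input_x` as a Parameter of the cell instead of passing a plain Tensor into construct. A minimal sketch of the new calling convention (my illustration, not part of the commit; assumes a GPU build with this change applied, and the class/variable names are hypothetical):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self, x, axis):
        super(Net, self).__init__()
        # input_x is now held as a Parameter so IndexAdd can write back to it.
        self.input_x = Parameter(Tensor(x), name='input_x')
        self.index_add = P.IndexAdd(axis)

    def construct(self, indices, y):
        # Only indices and input_y are passed in at call time.
        return self.index_add(self.input_x, indices, y)


context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x = np.arange(6).reshape(2, 3).astype(np.float32)
net = Net(x, axis=0)
indices = Tensor(np.array([0, 1]).astype(np.int32))
y = Tensor(np.ones((2, 3)).astype(np.float32))
output = net(indices, y)  # rows 0 and 1 of input_x each gain 1.0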

mindspore/ops/operations/math_ops.py (+9, -4)

@@ -4403,7 +4403,7 @@ class IndexAdd(PrimitiveWithInfer):
         axis (int): The dimension along which to index.

     Inputs:
-        - **input_x** (Tensor) - The input tensor to add to, with data type float64, float32, float16, int32, int16,
+        - **input_x** (Parameter) - The input tensor to add to, with data type float64, float32, float16, int32, int16,
           int8, uint8.
         - **indices** (Tensor) - The index of `input_x` on the `axis`th dimension to add to, with data type int32.
           The `indices` must be 1D with the size same as the size of the `axis`th dimension of `input_y`. The values
@@ -4428,21 +4428,26 @@ class IndexAdd(PrimitiveWithInfer):
         [ 5.  5.  7.5]
         [ 8.  7. 10.5]]
     """
+    __mindspore_signature__ = (
+        sig.make_sig('input_x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
+        sig.make_sig('input_y', dtype=sig.sig_dtype.T)
+    )

     @prim_attr_register
     def __init__(self, axis, use_lock=True, check_index_bound=True):
         """Initialize InplaceAdd"""
-        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
+        self.init_prim_io_names(inputs=['input_x', 'indices', 'input_y'], outputs=['output'])
         self.axis = axis
         validator.check_value_type('axis', axis, [int], self.name)

     def infer_dtype(self, x_dtype, idx_type, y_dtype):
-        args = {'x': x_dtype, 'y': y_dtype}
+        args = {'input_x': x_dtype, 'input_y': y_dtype}
         valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.int32, mstype.int16, mstype.int8,
                       mstype.uint8]
         validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
         valid_idx_type = [mstype.int32]
-        validator.check_tensor_dtype_valid("idx_type", idx_type, valid_idx_type, self.name)
+        validator.check_tensor_dtype_valid('indices', idx_type, valid_idx_type, self.name)
         return x_dtype

     def infer_shape(self, x_shape, idx_shape, y_shape):
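Note the new `__mindspore_signature__`: marking `input_x` with `sig.sig_rw.RW_WRITE` declares it as an input the op writes back to, which is why the docstring above now requires a Parameter. A hedged sketch of the op-level effect as I read that flag (my illustration, not from the commit):

# Assumes MindSpore with this commit applied, running on GPU in PyNative mode.
import numpy as np
import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
input_x = Parameter(Tensor(np.zeros((3, 3)).astype(np.float32)), name='input_x')
indices = Tensor(np.array([0, 2]).astype(np.int32))
input_y = Tensor(np.ones((2, 3)).astype(np.float32))
output = P.IndexAdd(axis=0)(input_x, indices, input_y)
# output holds the sum; per RW_WRITE, input_x itself should be updated in place.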


tests/st/ops/gpu/test_index_add_op.py (+65, -61)

@@ -19,18 +19,19 @@ import pytest
 import mindspore
 import mindspore.context as context
 import mindspore.nn as nn
-from mindspore import Tensor
+from mindspore import Tensor, Parameter, ParameterTuple
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C


 class NetIndexAdd(nn.Cell):
-    def __init__(self, axis):
+    def __init__(self, x, axis):
         super(NetIndexAdd, self).__init__()
+        self.input_x = Parameter(Tensor(x), name='x')
         self.index_add = P.IndexAdd(axis)

-    def construct(self, x, idx, y):
-        z = self.index_add(x, idx, y)
+    def construct(self, idx, y):
+        z = self.index_add(self.input_x, idx, y)
         return z
@@ -45,12 +46,12 @@ def test_index_add():
     expect = np.copy(x)
     expect[idx0, :, :, :] = expect[idx0, :, :, :] + y0
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    net = NetIndexAdd(axis0)
-    output = net(Tensor(x), Tensor(idx0), Tensor(y0))
+    net = NetIndexAdd(x, axis0)
+    output = net(Tensor(idx0), Tensor(y0))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis0)
-    output = net(Tensor(x), Tensor(idx0), Tensor(y0))
+    net = NetIndexAdd(x, axis0)
+    output = net(Tensor(idx0), Tensor(y0))
     assert (output.asnumpy() == expect).all()

     y1 = np.ndarray((2, 2, 4, 4)).astype(np.float32)
@@ -60,12 +61,12 @@ def test_index_add():
     expect = np.copy(x)
     expect[:, idx1, :, :] = expect[:, idx1, :, :] + y1
     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    net = NetIndexAdd(axis1)
-    output = net(Tensor(x), Tensor(idx1), Tensor(y1))
+    net = NetIndexAdd(x, axis1)
+    output = net(Tensor(idx1), Tensor(y1))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis1)
-    output = net(Tensor(x), Tensor(idx1), Tensor(y1))
+    net = NetIndexAdd(x, axis1)
+    output = net(Tensor(idx1), Tensor(y1))
     assert (output.asnumpy() == expect).all()

     y2 = np.ones((2, 3, 2, 4)).astype(np.float32)
@@ -75,12 +76,12 @@ def test_index_add():
     expect = np.copy(x)
     expect[:, :, idx2, :] = expect[:, :, idx2, :] + y2
     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    net = NetIndexAdd(axis2)
-    output = net(Tensor(x), Tensor(idx2), Tensor(y2))
+    net = NetIndexAdd(x, axis2)
+    output = net(Tensor(idx2), Tensor(y2))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis2)
-    output = net(Tensor(x), Tensor(idx2), Tensor(y2))
+    net = NetIndexAdd(x, axis2)
+    output = net(Tensor(idx2), Tensor(y2))
     assert (output.asnumpy() == expect).all()

     y3 = np.ones((2, 3, 4, 3)).astype(np.float32)
@@ -90,12 +91,12 @@ def test_index_add():
     expect = np.copy(x)
     expect[:, :, :, idx3] = expect[:, :, :, idx3] + y3
     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    net = NetIndexAdd(axis3)
-    output = net(Tensor(x), Tensor(idx3), Tensor(y3))
+    net = NetIndexAdd(x, axis3)
+    output = net(Tensor(idx3), Tensor(y3))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis3)
-    output = net(Tensor(x), Tensor(idx3), Tensor(y3))
+    net = NetIndexAdd(x, axis3)
+    output = net(Tensor(idx3), Tensor(y3))
     assert (output.asnumpy() == expect).all()




@@ -110,12 +111,12 @@ def test_index_add_float16():
     expect = np.copy(x)
     expect[:, idx, :] = expect[:, idx, :] + y
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()




@@ -130,12 +131,12 @@ def test_index_add_int32():
     expect = np.copy(x)
     expect[:, idx, :] = expect[:, idx, :] + y
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()




@@ -150,12 +151,12 @@ def test_index_add_int8():
     expect = np.copy(x)
     expect[:, idx, :] = expect[:, idx, :] + y
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()




@@ -170,12 +171,12 @@ def test_index_add_uint8():
     expect = np.copy(x)
     expect[:, idx, :] = expect[:, idx, :] + y
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()




@@ -190,12 +191,12 @@ def test_index_add_float64():
     expect = np.copy(x)
     expect[:, idx, :] = expect[:, idx, :] + y
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()




@@ -210,12 +211,12 @@ def test_index_add_int16():
     expect = np.copy(x)
     expect[:, idx, :] = expect[:, idx, :] + y
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    net = NetIndexAdd(axis)
-    output = net(Tensor(x), Tensor(idx), Tensor(y))
+    net = NetIndexAdd(x, axis)
+    output = net(Tensor(idx), Tensor(y))
     assert (output.asnumpy() == expect).all()




@@ -227,55 +228,56 @@ def test_index_add_invalid_inputs():
     y = np.ones((2, 2, 4), dtype=np.uint8)
     with pytest.raises(TypeError):
         #axis not int
-        net = NetIndexAdd(1.0)
+        net = NetIndexAdd(x, 1.0)

         #x and y don't have the same type
         y = np.ones((2, 2, 4), dtype=np.float32)
         idx = np.array([0, 1]).astype(np.int32)
-        net = NetIndexAdd(1)
-        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+        net = NetIndexAdd(x, 1)
+        _ = net(Tensor(idx), Tensor(y))

     with pytest.raises(ValueError):
         #index size not the same as len(y[axis])
         idx = np.array([0]).astype(np.int32)
-        net = NetIndexAdd(1)
-        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+        net = NetIndexAdd(x, 1)
+        _ = net(Tensor(idx), Tensor(y))

         #x and y don't have same rank
         y = np.ones((2, 2), dtype=np.uint8)
         idx = np.array([0, 1]).astype(np.int32)
-        net = NetIndexAdd(1)
-        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+        net = NetIndexAdd(x, 1)
+        _ = net(Tensor(idx), Tensor(y))

         #x and y don't have same shape on dimensions other than axis-th dimension
         y = np.ones((2, 2, 5), dtype=np.uint8)
         idx = np.array([0, 1]).astype(np.int32)
-        net = NetIndexAdd(1)
-        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+        net = NetIndexAdd(x, 1)
+        _ = net(Tensor(idx), Tensor(y))

     with pytest.raises(RuntimeError) as info:
         #index value not in the range of 0 to len(x[axis])
         idx = np.array([5, 6]).astype(np.int32)
-        net = NetIndexAdd(1)
-        _ = net(Tensor(x), Tensor(idx), Tensor(y))
+        net = NetIndexAdd(x, 1)
+        _ = net(Tensor(idx), Tensor(y))
     assert "out of range" in str(info.value)


 class IndexAddGradNet(nn.Cell):
     def __init__(self, network):
         super(IndexAddGradNet, self).__init__()
-        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)
         self.network = network
+        self.params = ParameterTuple(network.trainable_params())

-    def construct(self, x, idx, y, dout):
-        out = self.grad(self.network)(x, idx, y, dout)
+    def construct(self, idx, y, dout):
+        out = self.grad(self.network, self.params)(idx, y, dout)
         return out


 def index_add_grad_with_type(nptype):
-    net = NetIndexAdd(1)
+    x = np.arange(15).reshape(5, 3).astype(nptype)
+    net = NetIndexAdd(x, 1)
     grad_net = IndexAddGradNet(net)
-    x = Tensor(np.arange(15).reshape(5, 3).astype(nptype))
     y = Tensor(np.arange(5).reshape(5, 1).astype(nptype))
     dout = Tensor(np.array([[63., 64., 65.],
                             [66., 67., 68.],
@@ -283,7 +285,9 @@ def index_add_grad_with_type(nptype):
                             [72., 73., 74.],
                             [75., 76., 77.]]).astype(nptype))
     index = Tensor(np.array([1]), dtype=mindspore.int32)
-    xgrad, _, ygrad = grad_net(x, index, y, dout)
+    output = grad_net(index, y, dout)
+    ygrad = output[0][1]
+    xgrad = output[1][0]
     expect_xgrad = np.array([[63., 64., 65.],
                              [66., 67., 68.],
                              [69., 70., 71.],
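For readers puzzling over the new unpacking: with both get_all=True and get_by_list=True, C.GradOperation returns the gradients of the cell's inputs first and the gradients of the listed parameters second, which is why the test indexes into a nested tuple. A short annotated sketch of that structure (my annotation, assuming the class definitions above):

output = grad_net(index, y, dout)
input_grads = output[0]  # one entry per construct() input: (d_indices, d_input_y)
param_grads = output[1]  # one entry per parameter in the ParameterTuple: (d_input_x,)
ygrad = input_grads[1]   # matches output[0][1] in the test
xgrad = param_grads[0]   # matches output[1][0] in the test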

