
fixed ScatterUpdate

tags/v0.5.0-beta
jiangjinsheng 5 years ago
commit dc548afb93
1 changed file with 14 additions and 9 deletions
  1. mindspore/ops/operations/array_ops.py  +14  -9

mindspore/ops/operations/array_ops.py  (+14, -9)

@@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer):
 Creates an empty tensor, and set values by scattering the update tensor depending on indices.

 Inputs:
-    - **indices** (Tensor) - The index of scattering in the new tensor.
+    - **indices** (Tensor) - The index of scattering in the new tensor. With int32 data type.
     - **update** (Tensor) - The source Tensor to be scattered.
     - **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices.


@@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer):
 def __infer__(self, indices, update, shape):
     shp = shape['value']
     validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
-    validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
+    validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
     validator.check_value_type("shape", shp, [tuple], self.name)
     for i, x in enumerate(shp):
         validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
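
For reference, a minimal sketch of calling ScatterNd once this check lands: indices must now be int32 rather than any integer type. This assumes a MindSpore build of this era with the usual import style; the values are illustrative, not taken from the diff.

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

scatter_nd = P.ScatterNd()
# After this change, indices must be int32; other integer types are rejected by the validator.
indices = Tensor(np.array([[0], [2]]), mindspore.int32)
update = Tensor(np.array([3.2, 1.1]), mindspore.float32)
shape = (4,)
output = scatter_nd(indices, update, shape)   # expected: [3.2, 0.0, 1.1, 0.0]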
@@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer):


 Inputs:
     - **input_x** (Parameter) - The target tensor, with data type of Parameter.
-    - **indices** (Tensor) - The index of input tensor.
+    - **indices** (Tensor) - The index of input tensor. With int32 data type.
     - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
       and update.shape = indices.shape + input_x.shape[1:].


@@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer):
 Tensor, has the same shape and type as `input_x`.

 Examples:
-    >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
+    >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
+    >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
     >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
-    >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
+    >>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
+    >>> update = Tensor(np_update, mindspore.float32)
     >>> op = P.ScatterUpdate()
     >>> output = op(input_x, indices, update)
 """
@@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer):
 @prim_attr_register
 def __init__(self, use_locking=True):
     """Init ScatterUpdate"""
+    validator.check_value_type('use_locking', use_locking, [bool], self.name)
     self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])

 def infer_shape(self, x_shape, indices_shape, value_shape):
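
A hedged sketch of what the new use_locking check buys: constructing the primitive with a non-bool now fails fast in __init__, assuming the shared validator raises TypeError on a value-type mismatch as it does elsewhere in ops.

from mindspore.ops import operations as P

op = P.ScatterUpdate(use_locking=False)   # bool is still accepted
try:
    P.ScatterUpdate(use_locking="true")   # non-bool: rejected by the new check
except TypeError as err:
    print("rejected:", err)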
@@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer):
     return x_shape

 def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
-    validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
+    validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
     args = {"x": x_dtype, "value": value_dtype}
     validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
     return x_dtype
@@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer):


 Inputs:
     - **input_x** (Parameter) - The target tensor, with data type of Parameter.
-    - **indices** (Tensor) - The index of input tensor.
+    - **indices** (Tensor) - The index of input tensor, with int32 data type.
     - **update** (Tensor) - The tensor to add to the input tensor, has the same type as input.

 Outputs:
     Tensor, has the same shape and type as `input_x`.

 Examples:
-    >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
+    >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
+    >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
     >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
     >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
     >>> op = P.ScatterNdUpdate()
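
For clarity on this example's semantics, here is an equivalent numpy-only sketch: each row of indices addresses one element of input_x, and the matching scalar in update overwrites it. This is illustrative of the documented behaviour, not the operator's implementation.

import numpy as np

x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=np.float32)
indices = np.array([[0, 0], [1, 1]], dtype=np.int32)
update = np.array([1.0, 2.2], dtype=np.float32)

for idx, val in zip(indices, update):
    x[tuple(idx)] = val        # x[0, 0] = 1.0, then x[1, 1] = 2.2

# x is now [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]], with the same shape and dtype as the input.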
@@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
 @prim_attr_register
 def __init__(self, use_locking=True):
     """Init ScatterNdUpdate"""
+    validator.check_value_type('use_locking', use_locking, [bool], self.name)
     self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])

 def infer_shape(self, x_shape, indices_shape, value_shape):
@@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
     return x_shape

 def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
-    validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
+    validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
     args = {"x": x_dtype, "value": value_dtype}
     validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
     return x_dtype

