| @@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer): | |||||
| Creates an empty tensor, and set values by scattering the update tensor depending on indices. | Creates an empty tensor, and set values by scattering the update tensor depending on indices. | ||||
| Inputs: | Inputs: | ||||
| - **indices** (Tensor) - The index of scattering in the new tensor. | |||||
| - **indices** (Tensor) - The index of scattering in the new tensor. With int32 data type. | |||||
| - **update** (Tensor) - The source Tensor to be scattered. | - **update** (Tensor) - The source Tensor to be scattered. | ||||
| - **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices. | - **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices. | ||||
| @@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer): | |||||
| def __infer__(self, indices, update, shape): | def __infer__(self, indices, update, shape): | ||||
| shp = shape['value'] | shp = shape['value'] | ||||
| validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) | validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) | ||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name) | |||||
| validator.check_value_type("shape", shp, [tuple], self.name) | validator.check_value_type("shape", shp, [tuple], self.name) | ||||
| for i, x in enumerate(shp): | for i, x in enumerate(shp): | ||||
| validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name) | validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name) | ||||
| @@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer): | |||||
| Inputs: | Inputs: | ||||
| - **input_x** (Parameter) - The target tensor, with data type of Parameter. | - **input_x** (Parameter) - The target tensor, with data type of Parameter. | ||||
| - **indices** (Tensor) - The index of input tensor. | |||||
| - **indices** (Tensor) - The index of input tensor. With int32 data type. | |||||
| - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, | - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, | ||||
| and update.shape = indices.shape + input_x.shape[1:]. | and update.shape = indices.shape + input_x.shape[1:]. | ||||
| @@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer): | |||||
| Tensor, has the same shape and type as `input_x`. | Tensor, has the same shape and type as `input_x`. | ||||
| Examples: | Examples: | ||||
| >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) | |||||
| >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]) | |||||
| >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x") | |||||
| >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) | >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) | ||||
| >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) | |||||
| >>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]]) | |||||
| >>> update = Tensor(np_update, mindspore.float32) | |||||
| >>> op = P.ScatterUpdate() | >>> op = P.ScatterUpdate() | ||||
| >>> output = op(input_x, indices, update) | >>> output = op(input_x, indices, update) | ||||
| """ | """ | ||||
| @@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, use_locking=True): | def __init__(self, use_locking=True): | ||||
| """Init ScatterUpdate""" | """Init ScatterUpdate""" | ||||
| validator.check_value_type('use_locking', use_locking, [bool], self.name) | |||||
| self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) | self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) | ||||
| def infer_shape(self, x_shape, indices_shape, value_shape): | def infer_shape(self, x_shape, indices_shape, value_shape): | ||||
| @@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | ||||
| validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) | |||||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||||
| args = {"x": x_dtype, "value": value_dtype} | args = {"x": x_dtype, "value": value_dtype} | ||||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | ||||
| return x_dtype | return x_dtype | ||||
| @@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer): | |||||
| Inputs: | Inputs: | ||||
| - **input_x** (Parameter) - The target tensor, with data type of Parameter. | - **input_x** (Parameter) - The target tensor, with data type of Parameter. | ||||
| - **indices** (Tensor) - The index of input tensor. | |||||
| - **indices** (Tensor) - The index of input tensor, with int32 data type. | |||||
| - **update** (Tensor) - The tensor to add to the input tensor, has the same type as input. | - **update** (Tensor) - The tensor to add to the input tensor, has the same type as input. | ||||
| Outputs: | Outputs: | ||||
| Tensor, has the same shape and type as `input_x`. | Tensor, has the same shape and type as `input_x`. | ||||
| Examples: | Examples: | ||||
| >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) | |||||
| >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]) | |||||
| >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x") | |||||
| >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) | >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) | ||||
| >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) | >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) | ||||
| >>> op = P.ScatterNdUpdate() | >>> op = P.ScatterNdUpdate() | ||||
| @@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, use_locking=True): | def __init__(self, use_locking=True): | ||||
| """Init ScatterNdUpdate""" | """Init ScatterNdUpdate""" | ||||
| validator.check_value_type('use_locking', use_locking, [bool], self.name) | |||||
| self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) | self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) | ||||
| def infer_shape(self, x_shape, indices_shape, value_shape): | def infer_shape(self, x_shape, indices_shape, value_shape): | ||||
| @@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | ||||
| validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) | |||||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||||
| args = {"x": x_dtype, "value": value_dtype} | args = {"x": x_dtype, "value": value_dtype} | ||||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | ||||
| return x_dtype | return x_dtype | ||||