| @@ -69,6 +69,37 @@ class _ScatterOp(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class _ScatterOp_Dynamic(PrimitiveWithCheck): | |||
| """ | |||
| Defines Scatter operators with dynamic shape | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||
| sig.make_sig('updates', dtype=sig.sig_dtype.T) | |||
| ) | |||
| def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | |||
| if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]: | |||
| raise ValueError(f"For '{prim_name}', " | |||
| f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " | |||
| f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False): | |||
| """Initialize _ScatterOp_Dynamic""" | |||
| validator.check_value_type('use_locking', use_locking, [bool], self.name) | |||
| self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) | |||
| def check_shape(self, x_shape, indices_shape, updates_shape): | |||
| self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) | |||
| def check_dtype(self, x_dtype, indices_dtype, updates_dtype): | |||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||
| args = {"x": x_dtype, "updates": updates_dtype} | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| class _ScatterNdOp(_ScatterOp): | |||
| """ | |||
| Defines _ScatterNd operators | |||
| @@ -2723,7 +2754,7 @@ class TensorScatterUpdate(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class ScatterUpdate(_ScatterOp): | |||
| class ScatterUpdate(_ScatterOp_Dynamic): | |||
| """ | |||
| Updates tensor value by using input indices and value. | |||
| @@ -2757,20 +2788,12 @@ class ScatterUpdate(_ScatterOp): | |||
| [[2.0, 1.2, 1.0], | |||
| [3.0, 1.2, 1.0]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=True): | |||
| """Initialize ScatterUpdate""" | |||
| validator.check_value_type('use_locking', use_locking, [bool], self.name) | |||
| self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) | |||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | |||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||
| args = {"x": x_dtype, "value": value_dtype} | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| return x_dtype | |||
| class ScatterNdUpdate(_ScatterNdOp): | |||
| """ | |||
| Updates tensor value by using input indices and value. | |||
| @@ -2891,7 +2914,7 @@ class ScatterMin(_ScatterOp): | |||
| """ | |||
| class ScatterAdd(_ScatterOp): | |||
| class ScatterAdd(_ScatterOp_Dynamic): | |||
| """ | |||
| Updates the value of the input tensor through the add operation. | |||
| @@ -2923,6 +2946,11 @@ class ScatterAdd(_ScatterOp): | |||
| >>> output = scatter_add(input_x, indices, updates) | |||
| [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False): | |||
| """Initialize ScatterAdd""" | |||
| validator.check_value_type('use_locking', use_locking, [bool], self.name) | |||
| self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) | |||
| class ScatterSub(_ScatterOp): | |||