|
|
|
@@ -69,6 +69,37 @@ class _ScatterOp(PrimitiveWithInfer): |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
class _ScatterOp_Dynamic(PrimitiveWithCheck): |
|
|
|
""" |
|
|
|
Defines Scatter operators with dynamic shape |
|
|
|
""" |
|
|
|
__mindspore_signature__ = ( |
|
|
|
sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), |
|
|
|
sig.make_sig('indices', dtype=sig.sig_dtype.T1), |
|
|
|
sig.make_sig('updates', dtype=sig.sig_dtype.T) |
|
|
|
) |
|
|
|
|
|
|
|
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): |
|
|
|
if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]: |
|
|
|
raise ValueError(f"For '{prim_name}', " |
|
|
|
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " |
|
|
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, use_locking=False): |
|
|
|
"""Initialize _ScatterOp_Dynamic""" |
|
|
|
validator.check_value_type('use_locking', use_locking, [bool], self.name) |
|
|
|
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) |
|
|
|
|
|
|
|
def check_shape(self, x_shape, indices_shape, updates_shape): |
|
|
|
self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) |
|
|
|
|
|
|
|
def check_dtype(self, x_dtype, indices_dtype, updates_dtype): |
|
|
|
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) |
|
|
|
args = {"x": x_dtype, "updates": updates_dtype} |
|
|
|
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) |
|
|
|
|
|
|
|
|
|
|
|
class _ScatterNdOp(_ScatterOp): |
|
|
|
""" |
|
|
|
Defines _ScatterNd operators |
|
|
|
@@ -2723,7 +2754,7 @@ class TensorScatterUpdate(PrimitiveWithInfer): |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
class ScatterUpdate(_ScatterOp): |
|
|
|
class ScatterUpdate(_ScatterOp_Dynamic): |
|
|
|
""" |
|
|
|
Updates tensor value by using input indices and value. |
|
|
|
|
|
|
|
@@ -2757,20 +2788,12 @@ class ScatterUpdate(_ScatterOp): |
|
|
|
[[2.0, 1.2, 1.0], |
|
|
|
[3.0, 1.2, 1.0]] |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, use_locking=True): |
|
|
|
"""Initialize ScatterUpdate""" |
|
|
|
validator.check_value_type('use_locking', use_locking, [bool], self.name) |
|
|
|
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): |
|
|
|
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) |
|
|
|
args = {"x": x_dtype, "value": value_dtype} |
|
|
|
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
class ScatterNdUpdate(_ScatterNdOp): |
|
|
|
""" |
|
|
|
Updates tensor value by using input indices and value. |
|
|
|
@@ -2891,7 +2914,7 @@ class ScatterMin(_ScatterOp): |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
class ScatterAdd(_ScatterOp): |
|
|
|
class ScatterAdd(_ScatterOp_Dynamic): |
|
|
|
""" |
|
|
|
Updates the value of the input tensor through the add operation. |
|
|
|
|
|
|
|
@@ -2923,6 +2946,11 @@ class ScatterAdd(_ScatterOp): |
|
|
|
>>> output = scatter_add(input_x, indices, updates) |
|
|
|
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]] |
|
|
|
""" |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, use_locking=False): |
|
|
|
"""Initialize ScatterAdd""" |
|
|
|
validator.check_value_type('use_locking', use_locking, [bool], self.name) |
|
|
|
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) |
|
|
|
|
|
|
|
|
|
|
|
class ScatterSub(_ScatterOp): |
|
|
|
|