|
|
|
@@ -38,6 +38,39 @@ from ..._c_expression import signature_dtype as sig_dtype |
|
|
|
from ..._c_expression import typing |
|
|
|
|
|
|
|
|
|
|
|
class _ScatterOp(PrimitiveWithInfer):
    """
    Define Scatter operators.

    Shared base for scatter primitives taking (x, indices, updates): it
    registers the common signature, validates shapes/dtypes, and returns
    the parameter's own shape and dtype from inference.
    """
    __mindspore_signature__ = (
        ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
        ('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
    )

    @staticmethod
    def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
        """Raise ValueError unless updates_shape is [] or indices_shape + x_shape[1:]."""
        expected_shape = indices_shape + x_shape[1:]
        if updates_shape and updates_shape != expected_shape:
            raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
                             f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
                             f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Init _ScatterOp"""
        validator.check_value_type('use_locking', use_locking, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])

    def infer_shape(self, x_shape, indices_shape, updates_shape):
        # The scatter output always keeps the shape of the updated parameter.
        _ScatterOp._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
        # indices must be int32; x and updates must share one number dtype.
        validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
        validator.check_tensor_type_same({"x": x_dtype, "updates": updates_dtype},
                                         mstype.number_type, self.name)
        return x_dtype
|
|
|
|
|
|
|
|
|
|
|
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
    """Validate the attributes of a reduce primitive.

    Args:
        axis: Must be an int or a tuple; checked via `validator`.
        keep_dims: Must be a bool; checked via `validator`.
        prim_name (str): Primitive name used in validator error messages.

    Note: `keep_dims` is checked before `axis`, so a bad `keep_dims`
    is reported first when both are invalid.
    """
    validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
    validator.check_value_type('axis', axis, [int, tuple], prim_name)
|
|
|
@@ -2221,7 +2254,7 @@ class TensorScatterUpdate(PrimitiveWithInfer): |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
class ScatterUpdate(_ScatterOp):
    """
    Update tensor value by using input indices and value.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
        - **indices** (Tensor) - The index of input tensor. With int32 data type.
        - **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
          and updates.shape = indices.shape + input_x.shape[1:].

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Examples:
        >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
        >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
        >>> np_updates = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
        >>> updates = Tensor(np_updates, mindspore.float32)
        >>> op = P.ScatterUpdate()
        >>> output = op(input_x, indices, updates)
    """

    # NOTE(review): a corrupted merge left both the pre-refactor
    # PrimitiveWithInfer version (own signature, 'value' input name,
    # infer_shape/infer_dtype) and the post-refactor _ScatterOp version
    # interleaved here.  The signature, shape check and dtype check are
    # all provided by _ScatterOp; only the __init__ override remains,
    # because ScatterUpdate defaults use_locking to True instead of False.
    @prim_attr_register
    def __init__(self, use_locking=True):
        """Init ScatterUpdate"""
        validator.check_value_type('use_locking', use_locking, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
|
|
|
@@ -2323,14 +2346,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name): |
|
|
|
if updates_shape and updates_shape != indices_shape + x_shape[1:]: |
|
|
|
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or " |
|
|
|
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " |
|
|
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") |
|
|
|
|
|
|
|
|
|
|
|
class ScatterMax(PrimitiveWithInfer): |
|
|
|
class ScatterMax(_ScatterOp): |
|
|
|
""" |
|
|
|
Update the value of the input tensor through the max operation. |
|
|
|
|
|
|
|
@@ -2364,18 +2380,8 @@ class ScatterMax(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) |
|
|
|
validator.check_value_type('use_locking', use_locking, (bool,), self.name) |
|
|
|
|
|
|
|
    def infer_shape(self, x_shape, indices_shape, updates_shape):
        """Infer output shape; validates updates against indices/x, output is `x_shape`."""
        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape
|
|
|
|
|
|
|
    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
        """Infer output dtype: indices must be int32, x/updates one common number type."""
        validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
        # x and updates are validated together so they must share a dtype.
        args = {"x": x_dtype, "updates": updates_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x_dtype
|
|
|
|
|
|
|
|
|
|
|
class ScatterMin(PrimitiveWithInfer): |
|
|
|
class ScatterMin(_ScatterOp): |
|
|
|
""" |
|
|
|
Update the value of the input tensor through the min operation. |
|
|
|
|
|
|
|
@@ -2403,24 +2409,8 @@ class ScatterMin(PrimitiveWithInfer): |
|
|
|
[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]] |
|
|
|
""" |
|
|
|
|
|
|
|
    @prim_attr_register
    def __init__(self, use_locking=False):
        """Init ScatterMin.

        Args:
            use_locking (bool): Whether to protect the assignment by a lock. Default: False.
        """
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
        validator.check_value_type('use_locking', use_locking, (bool,), self.name)
|
|
|
|
|
|
|
    def infer_shape(self, x_shape, indices_shape, updates_shape):
        """Infer output shape; validates updates against indices/x, output is `x_shape`."""
        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape
|
|
|
|
|
|
|
    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
        """Infer output dtype: indices must be int32, x/updates one common number type."""
        validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
        # x and updates are validated together so they must share a dtype.
        args = {"x": x_dtype, "updates": updates_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x_dtype
|
|
|
|
|
|
|
|
|
|
|
class ScatterAdd(PrimitiveWithInfer): |
|
|
|
class ScatterAdd(_ScatterOp): |
|
|
|
""" |
|
|
|
Update the value of the input tensor through the add operation. |
|
|
|
|
|
|
|
@@ -2448,23 +2438,8 @@ class ScatterAdd(PrimitiveWithInfer): |
|
|
|
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]] |
|
|
|
""" |
|
|
|
|
|
|
|
    @prim_attr_register
    def __init__(self, use_locking=False):
        """Init ScatterAdd.

        Args:
            use_locking (bool): Whether to protect the assignment by a lock. Default: False.
        """
        validator.check_value_type('use_locking', use_locking, (bool,), self.name)
|
|
|
|
|
|
|
    def infer_shape(self, x_shape, indices_shape, updates_shape):
        """Infer output shape; validates updates against indices/x, output is `x_shape`."""
        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape
|
|
|
|
|
|
|
    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
        """Infer output dtype: indices must be int32, x/updates one common number type."""
        validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
        # x and updates are validated together so they must share a dtype.
        args = {'x': x_dtype, 'updates': updates_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x_dtype
|
|
|
|
|
|
|
|
|
|
|
class ScatterSub(PrimitiveWithInfer): |
|
|
|
class ScatterSub(_ScatterOp): |
|
|
|
""" |
|
|
|
Update the value of the input tensor through the sub operation. |
|
|
|
|
|
|
|
@@ -2492,20 +2467,63 @@ class ScatterSub(PrimitiveWithInfer): |
|
|
|
[[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]] |
|
|
|
""" |
|
|
|
|
|
|
|
    @prim_attr_register
    def __init__(self, use_locking=False):
        """Init ScatterSub.

        Args:
            use_locking (bool): Whether to protect the assignment by a lock. Default: False.
        """
        validator.check_value_type('use_locking', use_locking, (bool,), self.name)
|
|
|
|
|
|
|
    def infer_shape(self, x_shape, indices_shape, updates_shape):
        """Infer output shape; validates updates against indices/x, output is `x_shape`."""
        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape
|
|
|
class ScatterMul(_ScatterOp):
    """
    Update the value of the input tensor through the mul operation.

    Using given values to update tensor value through the mul operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    Args:
        use_locking (bool): Whether protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target parameter.
        - **indices** (Tensor) - The index to do mul operation whose data type should be mindspore.int32.
        - **updates** (Tensor) - The tensor doing the mul operation with `input_x`,
          the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Parameter, the updated `input_x`.

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
        >>> scatter_mul = P.ScatterMul()
        >>> output = scatter_mul(input_x, indices, updates)
        [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
    """
    # NOTE(review): a corrupted merge interleaved this docstring with the
    # removed ScatterSub.infer_dtype lines; the class body is docstring-only,
    # with all behavior (signature, __init__, infer_shape, infer_dtype)
    # inherited from _ScatterOp.  The example previously used
    # np.ones([[2.0, ...]]) which is invalid (np.ones takes a shape, not
    # values); it now uses np.array.
|
|
|
|
|
|
|
|
|
|
|
class ScatterDiv(_ScatterOp):
    """
    Update the value of the input tensor through the div operation.

    Using given values to update tensor value through the div operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    Args:
        use_locking (bool): Whether protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target parameter.
        - **indices** (Tensor) - The index to do div operation whose data type should be mindspore.int32.
        - **updates** (Tensor) - The tensor doing the div operation with `input_x`,
          the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Parameter, the updated `input_x`.

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
        >>> scatter_div = P.ScatterDiv()
        >>> output = scatter_div(input_x, indices, updates)
        [[3.0, 3.0, 3.0], [1.0, 1.0, 1.0]]
    """
    # Docstring-only class: signature, __init__, infer_shape and infer_dtype
    # are all inherited from _ScatterOp.  The example previously used
    # np.ones([[2.0, ...]]) which is invalid (np.ones takes a shape, not
    # values); it now uses np.array.
|
|
|
|
|
|
|
|
|
|
|
class SpaceToDepth(PrimitiveWithInfer): |
|
|
|
|