From 0bdf6c51a712057c7163d2d9549f4dfad59cb322 Mon Sep 17 00:00:00 2001 From: TFbunny Date: Wed, 4 Nov 2020 17:33:15 -0500 Subject: [PATCH] add dynamic shape for ScatterAdd/Update --- mindspore/ops/operations/array_ops.py | 48 +++++++++++++++++++++------ 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 5caee52bb6..049b26c3dc 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -69,6 +69,37 @@ class _ScatterOp(PrimitiveWithInfer): return x_dtype +class _ScatterOp_Dynamic(PrimitiveWithCheck): + """ + Defines Scatter operators with dynamic shape + """ + __mindspore_signature__ = ( + sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('updates', dtype=sig.sig_dtype.T) + ) + + def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): + if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]: + raise ValueError(f"For '{prim_name}', " + f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " + f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") + + @prim_attr_register + def __init__(self, use_locking=False): + """Initialize _ScatterOp_Dynamic""" + validator.check_value_type('use_locking', use_locking, [bool], self.name) + self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + + def check_shape(self, x_shape, indices_shape, updates_shape): + self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) + + def check_dtype(self, x_dtype, indices_dtype, updates_dtype): + validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) + args = {"x": x_dtype, "updates": updates_dtype} + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) + + class _ScatterNdOp(_ScatterOp): """ Defines _ScatterNd operators @@ -2723,7 +2754,7 @@ class TensorScatterUpdate(PrimitiveWithInfer): return x_dtype -class ScatterUpdate(_ScatterOp): +class ScatterUpdate(_ScatterOp_Dynamic): """ Updates tensor value by using input indices and value. @@ -2757,20 +2788,12 @@ class ScatterUpdate(_ScatterOp): [[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]] """ - @prim_attr_register def __init__(self, use_locking=True): """Initialize ScatterUpdate""" validator.check_value_type('use_locking', use_locking, [bool], self.name) self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) - def infer_dtype(self, x_dtype, indices_dtype, value_dtype): - validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) - args = {"x": x_dtype, "value": value_dtype} - validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) - return x_dtype - - class ScatterNdUpdate(_ScatterNdOp): """ Updates tensor value by using input indices and value. @@ -2891,7 +2914,7 @@ class ScatterMin(_ScatterOp): """ -class ScatterAdd(_ScatterOp): +class ScatterAdd(_ScatterOp_Dynamic): """ Updates the value of the input tensor through the add operation. @@ -2923,6 +2946,11 @@ class ScatterAdd(_ScatterOp): >>> output = scatter_add(input_x, indices, updates) [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]] """ + @prim_attr_register + def __init__(self, use_locking=False): + """Initialize ScatterAdd""" + validator.check_value_type('use_locking', use_locking, [bool], self.name) + self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) class ScatterSub(_ScatterOp):