Browse Source

!8236 add dynamic shape support to scatteradd/update

From: @TFbunny
Reviewed-by: @robingrosman,@tom__chen
Signed-off-by: @tom__chen
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
3a3b8c04c3
1 changed files with 38 additions and 10 deletions
  1. +38
    -10
      mindspore/ops/operations/array_ops.py

+ 38
- 10
mindspore/ops/operations/array_ops.py View File

@@ -69,6 +69,37 @@ class _ScatterOp(PrimitiveWithInfer):
return x_dtype


class _ScatterOp_Dynamic(PrimitiveWithCheck):
"""
Defines Scatter operators with dynamic shape
"""
__mindspore_signature__ = (
sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
sig.make_sig('updates', dtype=sig.sig_dtype.T)
)

def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")

@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize _ScatterOp_Dynamic"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])

def check_shape(self, x_shape, indices_shape, updates_shape):
self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)

def check_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)


class _ScatterNdOp(_ScatterOp):
"""
Defines _ScatterNd operators
@@ -2723,7 +2754,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
return x_dtype


class ScatterUpdate(_ScatterOp):
class ScatterUpdate(_ScatterOp_Dynamic):
"""
Updates tensor value by using input indices and value.

@@ -2757,20 +2788,12 @@ class ScatterUpdate(_ScatterOp):
[[2.0, 1.2, 1.0],
[3.0, 1.2, 1.0]]
"""

@prim_attr_register
def __init__(self, use_locking=True):
"""Initialize ScatterUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])

def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype


class ScatterNdUpdate(_ScatterNdOp):
"""
Updates tensor value by using input indices and value.
@@ -2891,7 +2914,7 @@ class ScatterMin(_ScatterOp):
"""


class ScatterAdd(_ScatterOp):
class ScatterAdd(_ScatterOp_Dynamic):
"""
Updates the value of the input tensor through the add operation.

@@ -2923,6 +2946,11 @@ class ScatterAdd(_ScatterOp):
>>> output = scatter_add(input_x, indices, updates)
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
"""
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize ScatterAdd"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])


class ScatterSub(_ScatterOp):


Loading…
Cancel
Save