diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index eedbdb6500..19828d3871 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2222,7 +2222,7 @@ class ScatterMax(PrimitiveWithInfer): Inputs: - **input_x** (Parameter) - The target parameter. - - **indices** (Tensor) - The index to do max operation whose data type should be int. + - **indices** (Tensor) - The index to do max operation whose data type should be mindspore.int32. - **updates** (Tensor) - The tensor doing the maximum operation with `input_x`, the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. @@ -2249,7 +2249,7 @@ class ScatterMax(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) + validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) args = {"x": x_dtype, "updates": updates_dtype} validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype @@ -2266,7 +2266,7 @@ class ScatterAdd(PrimitiveWithInfer): Inputs: - **input_x** (Parameter) - The target parameter. - - **indices** (Tensor) - The index to do add operation whose data type should be int. + - **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32. - **updates** (Tensor) - The tensor doing the add operation with `input_x`, the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. @@ -2292,7 +2292,7 @@ class ScatterAdd(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) + validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) args = {'x': x_dtype, 'updates': updates_dtype} validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype