Browse Source

!1705 fix issue scatter_add and scatter_max indices type limited to int32

Merge pull request !1705 from zhaozhenlong/fix-issues-scatter-add-max-indices
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
0fd5e702e1
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      mindspore/ops/operations/array_ops.py

+ 4
- 4
mindspore/ops/operations/array_ops.py View File

@@ -2222,7 +2222,7 @@ class ScatterMax(PrimitiveWithInfer):

Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do max operation whose data type should be int.
- **indices** (Tensor) - The index to do max operation whose data type should be mindspore.int32.
- **updates** (Tensor) - The tensor doing the maximum operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

@@ -2249,7 +2249,7 @@ class ScatterMax(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
@@ -2266,7 +2266,7 @@ class ScatterAdd(PrimitiveWithInfer):

Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do add operation whose data type should be int.
- **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32.
- **updates** (Tensor) - The tensor doing the add operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

@@ -2292,7 +2292,7 @@ class ScatterAdd(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
args = {'x': x_dtype, 'updates': updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype


Loading…
Cancel
Save