Browse Source

!12668 [GPU] fix index_add incorrect shape constraint

From: @tom__chen
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
70b0cb3c30
1 changed files with 3 additions and 5 deletions
  1. +3
    -5
      mindspore/ops/operations/math_ops.py

+ 3
- 5
mindspore/ops/operations/math_ops.py View File

@@ -4397,10 +4397,10 @@ class MatrixInverse(PrimitiveWithInfer):


class IndexAdd(PrimitiveWithInfer): class IndexAdd(PrimitiveWithInfer):
""" """
Adds tenosr y to specified axis and indices of tensor x.
Adds tensor y to specified axis and indices of tensor x.


Args: Args:
axis (int): The dimension along wich to index.
axis (int): The dimension along which to index.


Inputs: Inputs:
- **input_x** (Tensor) - The input tensor to add to, with data type float64, float32, float16, int32, int16, - **input_x** (Tensor) - The input tensor to add to, with data type float64, float32, float16, int32, int16,
@@ -4453,8 +4453,6 @@ class IndexAdd(PrimitiveWithInfer):
validator.check_int_range(self.axis, -x_rank - 1, x_rank, Rel.INC_BOTH, 'axis', self.name) validator.check_int_range(self.axis, -x_rank - 1, x_rank, Rel.INC_BOTH, 'axis', self.name)
axis = self.axis if self.axis >= 0 else x_rank + self.axis axis = self.axis if self.axis >= 0 else x_rank + self.axis
for dim in range(x_rank): for dim in range(x_rank):
if dim == axis:
validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.GE, self.name)
else:
if dim != axis:
validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.EQ, self.name) validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.EQ, self.name)
return x_shape return x_shape

Loading…
Cancel
Save