Browse Source

!14833 Bug fix and Raise added for dot ops

From: @anrui-wang
Reviewed-by: @liangchenghui,@c_34
Signed-off-by: @liangchenghui,@c_34
pull/14833/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
fbe8906e00
1 changed files with 19 additions and 0 deletions
  1. +19
    -0
      mindspore/ops/composite/math_ops.py

+ 19
- 0
mindspore/ops/composite/math_ops.py View File

@@ -279,6 +279,17 @@ def _check_invalid_input(x1_shape, x2_shape):
+ f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).')


@constexpr
def _typecheck_input_dot(x1_type, x2_type):
"""
Check input tensor types to be valid and confirm they are the same type for dot and batch dot ops.
"""
const_utils.check_type_valid(x1_type, [mstype.float16, mstype.float32], 'x1')
const_utils.check_type_valid(x2_type, [mstype.float16, mstype.float32], 'x2')
if x1_type != x2_type:
raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')


@constexpr
def _get_transpose_shape(x2_shape):
x2_shape_range = tuple(range(len(x2_shape)))
@@ -297,6 +308,11 @@ def dot(x1, x2):
Outputs:
Tensor, dot product of x1 and x2.

Raises:
TypeError: If type of x1 and x2 are not the same.
TpyeError: If dtype of x1 or x2 is not float16 or float32.
ValueError: If rank of x1 or x2 less than 2.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

@@ -314,6 +330,9 @@ def dot(x1, x2):
matmul_op = P.MatMul(False, False)
x1_shape = shape_op(x1)
x2_shape = shape_op(x2)
x1_type = F.dtype(x1)
x2_type = F.dtype(x2)
_typecheck_input_dot(x1_type, x2_type)
_check_invalid_input(x1_shape, x2_shape)

if len(x1_shape) > 2 or len(x2_shape) > 2:


Loading…
Cancel
Save