diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py index c9f1406880..774f86d6cc 100644 --- a/mindspore/ops/composite/math_ops.py +++ b/mindspore/ops/composite/math_ops.py @@ -279,6 +279,17 @@ def _check_invalid_input(x1_shape, x2_shape): + f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).') +@constexpr +def _typecheck_input_dot(x1_type, x2_type): + """ + Check input tensor types to be valid and confirm they are the same type for dot and batch dot ops. + """ + const_utils.check_type_valid(x1_type, [mstype.float16, mstype.float32], 'x1') + const_utils.check_type_valid(x2_type, [mstype.float16, mstype.float32], 'x2') + if x1_type != x2_type: + raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ') + + @constexpr def _get_transpose_shape(x2_shape): x2_shape_range = tuple(range(len(x2_shape))) @@ -297,6 +308,11 @@ def dot(x1, x2): Outputs: Tensor, dot product of x1 and x2. + Raises: + TypeError: If type of x1 and x2 are not the same. + TpyeError: If dtype of x1 or x2 is not float16 or float32. + ValueError: If rank of x1 or x2 less than 2. + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -314,6 +330,9 @@ def dot(x1, x2): matmul_op = P.MatMul(False, False) x1_shape = shape_op(x1) x2_shape = shape_op(x2) + x1_type = F.dtype(x1) + x2_type = F.dtype(x2) + _typecheck_input_dot(x1_type, x2_type) _check_invalid_input(x1_shape, x2_shape) if len(x1_shape) > 2 or len(x2_shape) > 2: