|
|
|
@@ -279,6 +279,17 @@ def _check_invalid_input(x1_shape, x2_shape): |
|
|
|
+ f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).') |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _typecheck_input_dot(x1_type, x2_type): |
|
|
|
""" |
|
|
|
Check input tensor types to be valid and confirm they are the same type for dot and batch dot ops. |
|
|
|
""" |
|
|
|
const_utils.check_type_valid(x1_type, [mstype.float16, mstype.float32], 'x1') |
|
|
|
const_utils.check_type_valid(x2_type, [mstype.float16, mstype.float32], 'x2') |
|
|
|
if x1_type != x2_type: |
|
|
|
raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ') |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _get_transpose_shape(x2_shape): |
|
|
|
x2_shape_range = tuple(range(len(x2_shape))) |
|
|
|
@@ -297,6 +308,11 @@ def dot(x1, x2): |
|
|
|
Outputs: |
|
|
|
Tensor, dot product of x1 and x2. |
|
|
|
|
|
|
|
Raises: |
|
|
|
TypeError: If type of x1 and x2 are not the same. |
|
|
|
TpyeError: If dtype of x1 or x2 is not float16 or float32. |
|
|
|
ValueError: If rank of x1 or x2 less than 2. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` ``CPU`` |
|
|
|
|
|
|
|
@@ -314,6 +330,9 @@ def dot(x1, x2): |
|
|
|
matmul_op = P.MatMul(False, False) |
|
|
|
x1_shape = shape_op(x1) |
|
|
|
x2_shape = shape_op(x2) |
|
|
|
x1_type = F.dtype(x1) |
|
|
|
x2_type = F.dtype(x2) |
|
|
|
_typecheck_input_dot(x1_type, x2_type) |
|
|
|
_check_invalid_input(x1_shape, x2_shape) |
|
|
|
|
|
|
|
if len(x1_shape) > 2 or len(x2_shape) > 2: |
|
|
|
|