|
|
|
@@ -336,6 +336,17 @@ def _get_batch_size(x1_shape, x2_shape): |
|
|
|
return x1_shape[0], x2_shape[0] |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _typecheck_input_batch_dot(x1_type, x2_type): |
|
|
|
""" |
|
|
|
Check input tensor types to be valid and confirm they are the same type for batch dot ops. |
|
|
|
""" |
|
|
|
const_utils.check_type_valid(x1_type, [mstype.float32], 'x1') |
|
|
|
const_utils.check_type_valid(x2_type, [mstype.float32], 'x2') |
|
|
|
if x1_type != x2_type: |
|
|
|
raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ') |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): |
|
|
|
""" |
|
|
|
@@ -419,15 +430,29 @@ def batch_dot(x1, x2, axes=None): |
|
|
|
Computation of batch dot product between samples in two tensors containing batch dims. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float16 or float32 |
|
|
|
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float16 or float32. x2's datatype should |
|
|
|
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float32 |
|
|
|
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should |
|
|
|
be same as x1's. |
|
|
|
- **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with dimensions |
|
|
|
specified for `a` and `b` each. If single value `N` passed, automatically picks up last N dims from |
|
|
|
`a` input shape and last N dims from `b` input shape in order as axes for each respectively. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, batch dot product of x1 and x2. |
|
|
|
Tensor, batch dot product of x1 and x2. The Shape of output for input shapes (batch, d1, axes, d2) and |
|
|
|
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4) |
|
|
|
|
|
|
|
.. math:: |
|
|
|
output = x1[batch, :] * x2[batch, :] |
|
|
|
|
|
|
|
Raises: |
|
|
|
TypeError: If shapes of x1 and x2 are not the same. |
|
|
|
ValueError: If rank of x1 or x2 less than 2. |
|
|
|
ValueError: If batch dim used in axes. |
|
|
|
ValueError: If dtype of x1 or x2 is not float32. |
|
|
|
ValueError: If len(axes) less than 2. |
|
|
|
ValueError: If axes is not one of those: None, int, (int, int). |
|
|
|
ValueError: If axes value is too high for dimensions of input arrays. |
|
|
|
ValueError: If batch size of x1 and x2 are not the same. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` ``CPU`` |
|
|
|
@@ -458,7 +483,7 @@ def batch_dot(x1, x2, axes=None): |
|
|
|
|
|
|
|
x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape) |
|
|
|
|
|
|
|
_typecheck_input(x1_type, x2_type) |
|
|
|
_typecheck_input_batch_dot(x1_type, x2_type) |
|
|
|
_check_batch_size(x1_batch_size, x2_batch_size) |
|
|
|
axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes) |
|
|
|
|
|
|
|
|