|
|
|
@@ -365,14 +365,16 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): |
|
|
|
raise ValueError("Require two axes inputs, given less") |
|
|
|
if isinstance(axes, tuple): |
|
|
|
axes = list(axes) |
|
|
|
for sub_axes in axes: |
|
|
|
if isinstance(sub_axes, (list, tuple)): |
|
|
|
raise ValueError("Require dimension to be in any of those: None, int, (int, int).") |
|
|
|
validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot') |
|
|
|
validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot') |
|
|
|
# Reverse if axis < 0 |
|
|
|
if axes[0] < 0: |
|
|
|
axes[0] += len(x1_shape) |
|
|
|
if axes[1] < 0: |
|
|
|
axes[1] += len(x2_shape) |
|
|
|
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape): |
|
|
|
raise ValueError( |
|
|
|
"Axes value too high for given input arrays dimensions.") |
|
|
|
elif isinstance(axes, int): |
|
|
|
if axes == 0: |
|
|
|
raise ValueError("Batch dim cannot be used as in axes.") |
|
|
|
@@ -429,6 +431,9 @@ def batch_dot(x1, x2, axes=None): |
|
|
|
""" |
|
|
|
Computation of batch dot product between samples in two tensors containing batch dims. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
output = x1[batch, :] * x2[batch, :] |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float32 |
|
|
|
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should |
|
|
|
@@ -439,13 +444,10 @@ def batch_dot(x1, x2, axes=None): |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, batch dot product of x1 and x2. The Shape of output for input shapes (batch, d1, axes, d2) and |
|
|
|
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4) |
|
|
|
|
|
|
|
.. math:: |
|
|
|
output = x1[batch, :] * x2[batch, :] |
|
|
|
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4) |
|
|
|
|
|
|
|
Raises: |
|
|
|
TypeError: If shapes of x1 and x2 are not the same. |
|
|
|
TypeError: If type of x1 and x2 are not the same. |
|
|
|
ValueError: If rank of x1 or x2 less than 2. |
|
|
|
ValueError: If batch dim used in axes. |
|
|
|
ValueError: If dtype of x1 or x2 is not float32. |
|
|
|
|