@@ -365,14 +365,16 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
            raise ValueError("Require two axes inputs, given less")
        if isinstance(axes, tuple):
            axes = list(axes)
        for sub_axes in axes:
            if isinstance(sub_axes, (list, tuple)):
                raise ValueError("Require dimension to be in any of those: None, int, (int, int).")
        validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
        validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
        # Reverse if axis < 0
        if axes[0] < 0:
            axes[0] += len(x1_shape)
        if axes[1] < 0:
            axes[1] += len(x2_shape)
        if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
            raise ValueError(
                "Axes value too high for given input arrays dimensions.")
    elif isinstance(axes, int):
        if axes == 0:
            raise ValueError("Batch dim cannot be used as in axes.")
@@ -429,6 +431,9 @@ def batch_dot(x1, x2, axes=None):
    """
    Computation of batch dot product between samples in two tensors containing batch dims.

    .. math::
        output = x1[batch, :] * x2[batch, :]
    Inputs:
        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float32.
        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
@@ -439,13 +444,10 @@ def batch_dot(x1, x2, axes=None):
    Outputs:
        Tensor, batch dot product of x1 and x2. The shape of the output for input shapes (batch, d1, axes, d2)
        and (batch, d3, axes, d4) is (batch, d1, d2, d3, d4).
    Raises:
        TypeError: If type of x1 and x2 are not the same.
        ValueError: If rank of x1 or x2 is less than 2.
        ValueError: If batch dim is used in axes.
        ValueError: If dtype of x1 or x2 is not float32.
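To illustrate the shape rule stated in Outputs, a NumPy einsum stand-in (illustration only, not the MindSpore API): each batch sample of x1 is contracted with the matching sample of x2 over the shared 'axes' dimension.

    import numpy as np

    # (batch, d1, axes, d2) and (batch, d3, axes, d4) -> (batch, d1, d2, d3, d4)
    x1 = np.ones((2, 3, 6, 4), dtype=np.float32)   # batch=2, d1=3, axes=6, d2=4
    x2 = np.ones((2, 5, 6, 7), dtype=np.float32)   # batch=2, d3=5, axes=6, d4=7
    out = np.einsum('biaj,bkal->bijkl', x1, x2)
    print(out.shape)  # (2, 3, 4, 5, 7)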