| @@ -372,6 +372,8 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): | |||||
| axes[0] += len(x1_shape) | axes[0] += len(x1_shape) | ||||
| if axes[1] < 0: | if axes[1] < 0: | ||||
| axes[1] += len(x2_shape) | axes[1] += len(x2_shape) | ||||
| validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot') | |||||
| validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot') | |||||
| if axes[0] > len(x1_shape) or axes[1] > len(x2_shape): | if axes[0] > len(x1_shape) or axes[1] > len(x2_shape): | ||||
| raise ValueError( | raise ValueError( | ||||
| "Axes value too high for given input arrays dimensions.") | "Axes value too high for given input arrays dimensions.") | ||||
| @@ -380,6 +382,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): | |||||
| raise ValueError("Batch dim cannot be used as in axes.") | raise ValueError("Batch dim cannot be used as in axes.") | ||||
| if axes < 0: | if axes < 0: | ||||
| axes = [axes + len(x1_shape), axes + len(x2_shape)] | axes = [axes + len(x1_shape), axes + len(x2_shape)] | ||||
| validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot') | |||||
| elif axes > len(x1_shape) or axes > len(x2_shape): | elif axes > len(x1_shape) or axes > len(x2_shape): | ||||
| raise ValueError( | raise ValueError( | ||||
| "Axes value too high for given input arrays dimensions.") | "Axes value too high for given input arrays dimensions.") | ||||
| @@ -448,11 +451,12 @@ def batch_dot(x1, x2, axes=None): | |||||
| Raises: | Raises: | ||||
| TypeError: If type of x1 and x2 are not the same. | TypeError: If type of x1 and x2 are not the same. | ||||
| TpyeError: If dtype of x1 or x2 is not float32. | |||||
| ValueError: If rank of x1 or x2 less than 2. | ValueError: If rank of x1 or x2 less than 2. | ||||
| ValueError: If batch dim used in axes. | ValueError: If batch dim used in axes. | ||||
| ValueError: If dtype of x1 or x2 is not float32. | |||||
| ValueError: If len(axes) less than 2. | ValueError: If len(axes) less than 2. | ||||
| ValueError: If axes is not one of those: None, int, (int, int). | ValueError: If axes is not one of those: None, int, (int, int). | ||||
| ValueError: If axes reversed from negative int is too low for dimensions of input arrays. | |||||
| ValueError: If axes value is too high for dimensions of input arrays. | ValueError: If axes value is too high for dimensions of input arrays. | ||||
| ValueError: If batch size of x1 and x2 are not the same. | ValueError: If batch size of x1 and x2 are not the same. | ||||