|
|
|
@@ -372,6 +372,8 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): |
|
|
|
axes[0] += len(x1_shape) |
|
|
|
if axes[1] < 0: |
|
|
|
axes[1] += len(x2_shape) |
|
|
|
validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot') |
|
|
|
validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot') |
|
|
|
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape): |
|
|
|
raise ValueError( |
|
|
|
"Axes value too high for given input arrays dimensions.") |
|
|
|
@@ -380,6 +382,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): |
|
|
|
raise ValueError("Batch dim cannot be used as in axes.") |
|
|
|
if axes < 0: |
|
|
|
axes = [axes + len(x1_shape), axes + len(x2_shape)] |
|
|
|
validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot') |
|
|
|
elif axes > len(x1_shape) or axes > len(x2_shape): |
|
|
|
raise ValueError( |
|
|
|
"Axes value too high for given input arrays dimensions.") |
|
|
|
@@ -448,11 +451,12 @@ def batch_dot(x1, x2, axes=None): |
|
|
|
|
|
|
|
Raises: |
|
|
|
TypeError: If type of x1 and x2 are not the same. |
|
|
|
TpyeError: If dtype of x1 or x2 is not float32. |
|
|
|
ValueError: If rank of x1 or x2 less than 2. |
|
|
|
ValueError: If batch dim used in axes. |
|
|
|
ValueError: If dtype of x1 or x2 is not float32. |
|
|
|
ValueError: If len(axes) less than 2. |
|
|
|
ValueError: If axes is not one of those: None, int, (int, int). |
|
|
|
ValueError: If axes reversed from negative int is too low for dimensions of input arrays. |
|
|
|
ValueError: If axes value is too high for dimensions of input arrays. |
|
|
|
ValueError: If batch size of x1 and x2 are not the same. |
|
|
|
|
|
|
|
|