diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py index d6f197b3a2..c9f1406880 100644 --- a/mindspore/ops/composite/math_ops.py +++ b/mindspore/ops/composite/math_ops.py @@ -372,6 +372,8 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): axes[0] += len(x1_shape) if axes[1] < 0: axes[1] += len(x2_shape) + validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot') + validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot') if axes[0] > len(x1_shape) or axes[1] > len(x2_shape): raise ValueError( "Axes value too high for given input arrays dimensions.") @@ -380,6 +382,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): raise ValueError("Batch dim cannot be used as in axes.") if axes < 0: axes = [axes + len(x1_shape), axes + len(x2_shape)] + validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot') elif axes > len(x1_shape) or axes > len(x2_shape): raise ValueError( "Axes value too high for given input arrays dimensions.") @@ -448,11 +451,12 @@ def batch_dot(x1, x2, axes=None): Raises: TypeError: If type of x1 and x2 are not the same. + TpyeError: If dtype of x1 or x2 is not float32. ValueError: If rank of x1 or x2 less than 2. ValueError: If batch dim used in axes. - ValueError: If dtype of x1 or x2 is not float32. ValueError: If len(axes) less than 2. ValueError: If axes is not one of those: None, int, (int, int). + ValueError: If axes reversed from negative int is too low for dimensions of input arrays. ValueError: If axes value is too high for dimensions of input arrays. ValueError: If batch size of x1 and x2 are not the same.