diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py index 59a9917a88..5dea841fbf 100644 --- a/mindspore/ops/composite/math_ops.py +++ b/mindspore/ops/composite/math_ops.py @@ -120,12 +120,14 @@ def _typecheck_input(x1_type, x2_type): @constexpr -def _validate_input(x1_shape, x2_shape, axes): +def _axes_int_check(x1_shape, x2_shape, axes): """ - Convert from single int axes to 2d tuple if required and check for validity with inputs. + Convert from single int axes to 2d tuple if required """ if isinstance(axes, int): - if axes <= 0: + if axes < 0: + raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}") + if axes == 0: # outer product, no input validation required return ([], []) if axes > len(x1_shape) or axes > len(x2_shape): @@ -135,13 +137,42 @@ def _validate_input(x1_shape, x2_shape, axes): x2_ind = tuple(range(len(x2_shape))[:axes]) axes = tuple((x1_ind, x2_ind)) axes = _int_to_tuple_conv(axes) - for i in range(len(axes[0])): # sizes already validated - if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: - raise ValueError( - "Given Axes are incompatible with given input arrays") return axes +@constexpr +def _validate_axes(x1_shape, x2_shape, axes): + """ + Checks for axes having the correct length according to input, for any value in axis + being out of range with given shape and also checking for compatiable axes values + with given inputs. + """ + shapes = [x1_shape, x2_shape] + + # axis length check + for ix_input, x_axes in enumerate(axes): + axes_len = len(x_axes) + shape_dim_len = len(shapes[ix_input]) + if axes_len > shape_dim_len: + raise ValueError(f"axes for input: {ix_input + 1} are of length: {axes_len} " + f"can only be max: {shape_dim_len} due to input shape.") + + # axis values range check + for ix_input, x_axes in enumerate(axes): + comp_shape = shapes[ix_input] + max_val = len(comp_shape) - 1 + min_val = -1 * len(comp_shape) + for _, x_value in enumerate(x_axes): + if not min_val <= x_value <= max_val: + raise ValueError(f"axes for input: {ix_input + 1} contains index: " + f"{x_value}, but range is: [{min_val}, {max_val}]") + + # axis value input compatibility + for i in range(len(axes[0])): # sizes already validated + if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: + raise ValueError( + "Given Axes are incompatible with given input arrays") + @constexpr def _calc_new_shape(shape, axes, position=0): """ @@ -208,7 +239,8 @@ def tensor_dot(x1, x2, axes): axes = _check_axes(axes) _typecheck_input(x1_type, x2_type) # input compability check & axes format update - axes = _validate_input(x1_shape, x2_shape, axes) + axes = _axes_int_check(x1_shape, x2_shape, axes) + _validate_axes(x1_shape, x2_shape, axes) x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) output_shape = x1_ret + x2_ret # combine free axes from both inputs