| @@ -120,12 +120,14 @@ def _typecheck_input(x1_type, x2_type): | |||||
| @constexpr | @constexpr | ||||
| def _validate_input(x1_shape, x2_shape, axes): | |||||
| def _axes_int_check(x1_shape, x2_shape, axes): | |||||
| """ | """ | ||||
| Convert from single int axes to 2d tuple if required and check for validity with inputs. | |||||
| Convert from single int axes to 2d tuple if required | |||||
| """ | """ | ||||
| if isinstance(axes, int): | if isinstance(axes, int): | ||||
| if axes <= 0: | |||||
| if axes < 0: | |||||
| raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}") | |||||
| if axes == 0: | |||||
| # outer product, no input validation required | # outer product, no input validation required | ||||
| return ([], []) | return ([], []) | ||||
| if axes > len(x1_shape) or axes > len(x2_shape): | if axes > len(x1_shape) or axes > len(x2_shape): | ||||
| @@ -135,13 +137,42 @@ def _validate_input(x1_shape, x2_shape, axes): | |||||
| x2_ind = tuple(range(len(x2_shape))[:axes]) | x2_ind = tuple(range(len(x2_shape))[:axes]) | ||||
| axes = tuple((x1_ind, x2_ind)) | axes = tuple((x1_ind, x2_ind)) | ||||
| axes = _int_to_tuple_conv(axes) | axes = _int_to_tuple_conv(axes) | ||||
| for i in range(len(axes[0])): # sizes already validated | |||||
| if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: | |||||
| raise ValueError( | |||||
| "Given Axes are incompatible with given input arrays") | |||||
| return axes | return axes | ||||
| @constexpr | |||||
| def _validate_axes(x1_shape, x2_shape, axes): | |||||
| """ | |||||
| Checks for axes having the correct length according to input, for any value in axis | |||||
| being out of range with given shape and also checking for compatiable axes values | |||||
| with given inputs. | |||||
| """ | |||||
| shapes = [x1_shape, x2_shape] | |||||
| # axis length check | |||||
| for ix_input, x_axes in enumerate(axes): | |||||
| axes_len = len(x_axes) | |||||
| shape_dim_len = len(shapes[ix_input]) | |||||
| if axes_len > shape_dim_len: | |||||
| raise ValueError(f"axes for input: {ix_input + 1} are of length: {axes_len} " | |||||
| f"can only be max: {shape_dim_len} due to input shape.") | |||||
| # axis values range check | |||||
| for ix_input, x_axes in enumerate(axes): | |||||
| comp_shape = shapes[ix_input] | |||||
| max_val = len(comp_shape) - 1 | |||||
| min_val = -1 * len(comp_shape) | |||||
| for _, x_value in enumerate(x_axes): | |||||
| if not min_val <= x_value <= max_val: | |||||
| raise ValueError(f"axes for input: {ix_input + 1} contains index: " | |||||
| f"{x_value}, but range is: [{min_val}, {max_val}]") | |||||
| # axis value input compatibility | |||||
| for i in range(len(axes[0])): # sizes already validated | |||||
| if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: | |||||
| raise ValueError( | |||||
| "Given Axes are incompatible with given input arrays") | |||||
| @constexpr | @constexpr | ||||
| def _calc_new_shape(shape, axes, position=0): | def _calc_new_shape(shape, axes, position=0): | ||||
| """ | """ | ||||
| @@ -208,7 +239,8 @@ def tensor_dot(x1, x2, axes): | |||||
| axes = _check_axes(axes) | axes = _check_axes(axes) | ||||
| _typecheck_input(x1_type, x2_type) | _typecheck_input(x1_type, x2_type) | ||||
| # input compability check & axes format update | # input compability check & axes format update | ||||
| axes = _validate_input(x1_shape, x2_shape, axes) | |||||
| axes = _axes_int_check(x1_shape, x2_shape, axes) | |||||
| _validate_axes(x1_shape, x2_shape, axes) | |||||
| x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) | x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) | ||||
| x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) | x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) | ||||
| output_shape = x1_ret + x2_ret # combine free axes from both inputs | output_shape = x1_ret + x2_ret # combine free axes from both inputs | ||||