| @@ -120,12 +120,14 @@ def _typecheck_input(x1_type, x2_type): | |||
| @constexpr | |||
| def _validate_input(x1_shape, x2_shape, axes): | |||
| def _axes_int_check(x1_shape, x2_shape, axes): | |||
| """ | |||
| Convert from single int axes to 2d tuple if required and check for validity with inputs. | |||
| Convert from single int axes to 2d tuple if required | |||
| """ | |||
| if isinstance(axes, int): | |||
| if axes <= 0: | |||
| if axes < 0: | |||
| raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}") | |||
| if axes == 0: | |||
| # outer product, no input validation required | |||
| return ([], []) | |||
| if axes > len(x1_shape) or axes > len(x2_shape): | |||
| @@ -135,13 +137,42 @@ def _validate_input(x1_shape, x2_shape, axes): | |||
| x2_ind = tuple(range(len(x2_shape))[:axes]) | |||
| axes = tuple((x1_ind, x2_ind)) | |||
| axes = _int_to_tuple_conv(axes) | |||
| for i in range(len(axes[0])): # sizes already validated | |||
| if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: | |||
| raise ValueError( | |||
| "Given Axes are incompatible with given input arrays") | |||
| return axes | |||
| @constexpr | |||
| def _validate_axes(x1_shape, x2_shape, axes): | |||
| """ | |||
| Checks for axes having the correct length according to input, for any value in axis | |||
| being out of range with given shape and also checking for compatiable axes values | |||
| with given inputs. | |||
| """ | |||
| shapes = [x1_shape, x2_shape] | |||
| # axis length check | |||
| for ix_input, x_axes in enumerate(axes): | |||
| axes_len = len(x_axes) | |||
| shape_dim_len = len(shapes[ix_input]) | |||
| if axes_len > shape_dim_len: | |||
| raise ValueError(f"axes for input: {ix_input + 1} are of length: {axes_len} " | |||
| f"can only be max: {shape_dim_len} due to input shape.") | |||
| # axis values range check | |||
| for ix_input, x_axes in enumerate(axes): | |||
| comp_shape = shapes[ix_input] | |||
| max_val = len(comp_shape) - 1 | |||
| min_val = -1 * len(comp_shape) | |||
| for _, x_value in enumerate(x_axes): | |||
| if not min_val <= x_value <= max_val: | |||
| raise ValueError(f"axes for input: {ix_input + 1} contains index: " | |||
| f"{x_value}, but range is: [{min_val}, {max_val}]") | |||
| # axis value input compatibility | |||
| for i in range(len(axes[0])): # sizes already validated | |||
| if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: | |||
| raise ValueError( | |||
| "Given Axes are incompatible with given input arrays") | |||
| @constexpr | |||
| def _calc_new_shape(shape, axes, position=0): | |||
| """ | |||
| @@ -208,7 +239,8 @@ def tensor_dot(x1, x2, axes): | |||
| axes = _check_axes(axes) | |||
| _typecheck_input(x1_type, x2_type) | |||
| # input compability check & axes format update | |||
| axes = _validate_input(x1_shape, x2_shape, axes) | |||
| axes = _axes_int_check(x1_shape, x2_shape, axes) | |||
| _validate_axes(x1_shape, x2_shape, axes) | |||
| x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) | |||
| x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) | |||
| output_shape = x1_ret + x2_ret # combine free axes from both inputs | |||