|
|
@@ -120,12 +120,14 @@ def _typecheck_input(x1_type, x2_type): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
@constexpr |
|
|
def _validate_input(x1_shape, x2_shape, axes): |
|
|
|
|
|
|
|
|
def _axes_int_check(x1_shape, x2_shape, axes): |
|
|
""" |
|
|
""" |
|
|
Convert from single int axes to 2d tuple if required and check for validity with inputs. |
|
|
|
|
|
|
|
|
Convert from single int axes to 2d tuple if required |
|
|
""" |
|
|
""" |
|
|
if isinstance(axes, int): |
|
|
if isinstance(axes, int): |
|
|
if axes <= 0: |
|
|
|
|
|
|
|
|
if axes < 0: |
|
|
|
|
|
raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}") |
|
|
|
|
|
if axes == 0: |
|
|
# outer product, no input validation required |
|
|
# outer product, no input validation required |
|
|
return ([], []) |
|
|
return ([], []) |
|
|
if axes > len(x1_shape) or axes > len(x2_shape): |
|
|
if axes > len(x1_shape) or axes > len(x2_shape): |
|
|
@@ -135,13 +137,42 @@ def _validate_input(x1_shape, x2_shape, axes): |
|
|
x2_ind = tuple(range(len(x2_shape))[:axes]) |
|
|
x2_ind = tuple(range(len(x2_shape))[:axes]) |
|
|
axes = tuple((x1_ind, x2_ind)) |
|
|
axes = tuple((x1_ind, x2_ind)) |
|
|
axes = _int_to_tuple_conv(axes) |
|
|
axes = _int_to_tuple_conv(axes) |
|
|
for i in range(len(axes[0])): # sizes already validated |
|
|
|
|
|
if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: |
|
|
|
|
|
raise ValueError( |
|
|
|
|
|
"Given Axes are incompatible with given input arrays") |
|
|
|
|
|
return axes |
|
|
return axes |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
|
|
def _validate_axes(x1_shape, x2_shape, axes): |
|
|
|
|
|
""" |
|
|
|
|
|
Checks for axes having the correct length according to input, for any value in axis |
|
|
|
|
|
being out of range with given shape and also checking for compatiable axes values |
|
|
|
|
|
with given inputs. |
|
|
|
|
|
""" |
|
|
|
|
|
shapes = [x1_shape, x2_shape] |
|
|
|
|
|
|
|
|
|
|
|
# axis length check |
|
|
|
|
|
for ix_input, x_axes in enumerate(axes): |
|
|
|
|
|
axes_len = len(x_axes) |
|
|
|
|
|
shape_dim_len = len(shapes[ix_input]) |
|
|
|
|
|
if axes_len > shape_dim_len: |
|
|
|
|
|
raise ValueError(f"axes for input: {ix_input + 1} are of length: {axes_len} " |
|
|
|
|
|
f"can only be max: {shape_dim_len} due to input shape.") |
|
|
|
|
|
|
|
|
|
|
|
# axis values range check |
|
|
|
|
|
for ix_input, x_axes in enumerate(axes): |
|
|
|
|
|
comp_shape = shapes[ix_input] |
|
|
|
|
|
max_val = len(comp_shape) - 1 |
|
|
|
|
|
min_val = -1 * len(comp_shape) |
|
|
|
|
|
for _, x_value in enumerate(x_axes): |
|
|
|
|
|
if not min_val <= x_value <= max_val: |
|
|
|
|
|
raise ValueError(f"axes for input: {ix_input + 1} contains index: " |
|
|
|
|
|
f"{x_value}, but range is: [{min_val}, {max_val}]") |
|
|
|
|
|
|
|
|
|
|
|
# axis value input compatibility |
|
|
|
|
|
for i in range(len(axes[0])): # sizes already validated |
|
|
|
|
|
if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: |
|
|
|
|
|
raise ValueError( |
|
|
|
|
|
"Given Axes are incompatible with given input arrays") |
|
|
|
|
|
|
|
|
@constexpr |
|
|
@constexpr |
|
|
def _calc_new_shape(shape, axes, position=0): |
|
|
def _calc_new_shape(shape, axes, position=0): |
|
|
""" |
|
|
""" |
|
|
@@ -208,7 +239,8 @@ def tensor_dot(x1, x2, axes): |
|
|
axes = _check_axes(axes) |
|
|
axes = _check_axes(axes) |
|
|
_typecheck_input(x1_type, x2_type) |
|
|
_typecheck_input(x1_type, x2_type) |
|
|
# input compability check & axes format update |
|
|
# input compability check & axes format update |
|
|
axes = _validate_input(x1_shape, x2_shape, axes) |
|
|
|
|
|
|
|
|
axes = _axes_int_check(x1_shape, x2_shape, axes) |
|
|
|
|
|
_validate_axes(x1_shape, x2_shape, axes) |
|
|
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) |
|
|
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) |
|
|
x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) |
|
|
x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) |
|
|
output_shape = x1_ret + x2_ret # combine free axes from both inputs |
|
|
output_shape = x1_ret + x2_ret # combine free axes from both inputs |
|
|
|