Browse Source

TensorDot checking refactor

lintfix
tags/v1.1.0
danishnxt 5 years ago
parent
commit
a7ca8a4c1d
1 changed files with 40 additions and 8 deletions
  1. +40
    -8
      mindspore/ops/composite/math_ops.py

+ 40
- 8
mindspore/ops/composite/math_ops.py View File

@@ -120,12 +120,14 @@ def _typecheck_input(x1_type, x2_type):


@constexpr
def _validate_input(x1_shape, x2_shape, axes):
def _axes_int_check(x1_shape, x2_shape, axes):
"""
Convert from single int axes to 2d tuple if required and check for validity with inputs.
Convert from single int axes to 2d tuple if required
"""
if isinstance(axes, int):
if axes <= 0:
if axes < 0:
raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}")
if axes == 0:
# outer product, no input validation required
return ([], [])
if axes > len(x1_shape) or axes > len(x2_shape):
@@ -135,13 +137,42 @@ def _validate_input(x1_shape, x2_shape, axes):
x2_ind = tuple(range(len(x2_shape))[:axes])
axes = tuple((x1_ind, x2_ind))
axes = _int_to_tuple_conv(axes)
for i in range(len(axes[0])): # sizes already validated
if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
raise ValueError(
"Given Axes are incompatible with given input arrays")
return axes


@constexpr
def _validate_axes(x1_shape, x2_shape, axes):
"""
Checks for axes having the correct length according to input, for any value in axis
being out of range with given shape and also checking for compatiable axes values
with given inputs.
"""
shapes = [x1_shape, x2_shape]

# axis length check
for ix_input, x_axes in enumerate(axes):
axes_len = len(x_axes)
shape_dim_len = len(shapes[ix_input])
if axes_len > shape_dim_len:
raise ValueError(f"axes for input: {ix_input + 1} are of length: {axes_len} "
f"can only be max: {shape_dim_len} due to input shape.")

# axis values range check
for ix_input, x_axes in enumerate(axes):
comp_shape = shapes[ix_input]
max_val = len(comp_shape) - 1
min_val = -1 * len(comp_shape)
for _, x_value in enumerate(x_axes):
if not min_val <= x_value <= max_val:
raise ValueError(f"axes for input: {ix_input + 1} contains index: "
f"{x_value}, but range is: [{min_val}, {max_val}]")

# axis value input compatibility
for i in range(len(axes[0])): # sizes already validated
if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
raise ValueError(
"Given Axes are incompatible with given input arrays")

@constexpr
def _calc_new_shape(shape, axes, position=0):
"""
@@ -208,7 +239,8 @@ def tensor_dot(x1, x2, axes):
axes = _check_axes(axes)
_typecheck_input(x1_type, x2_type)
# input compability check & axes format update
axes = _validate_input(x1_shape, x2_shape, axes)
axes = _axes_int_check(x1_shape, x2_shape, axes)
_validate_axes(x1_shape, x2_shape, axes)
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
output_shape = x1_ret + x2_ret # combine free axes from both inputs


Loading…
Cancel
Save