|
|
|
@@ -148,7 +148,7 @@ def _axes_int_check(x1_shape, x2_shape, axes): |
|
|
|
def _validate_axes(x1_shape, x2_shape, axes): |
|
|
|
""" |
|
|
|
Checks for axes having the correct length according to input, for any value in axis |
|
|
|
being out of range with given shape and also checking for compatiable axes values |
|
|
|
being out of range with given shape and also checking for compatible axes values |
|
|
|
with given inputs. |
|
|
|
""" |
|
|
|
shapes = [x1_shape, x2_shape] |
|
|
|
@@ -250,7 +250,7 @@ def tensor_dot(x1, x2, axes): |
|
|
|
x2_type = F.dtype(x2) |
|
|
|
axes = _check_axes(axes) |
|
|
|
_typecheck_input(x1_type, x2_type) |
|
|
|
# input compability check & axes format update |
|
|
|
# input compatibility check & axes format update |
|
|
|
axes = _axes_int_check(x1_shape, x2_shape, axes) |
|
|
|
_validate_axes(x1_shape, x2_shape, axes) |
|
|
|
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) |
|
|
|
@@ -297,7 +297,7 @@ def dot(x1, x2): |
|
|
|
Examples: |
|
|
|
>>> input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32) |
|
|
|
>>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32) |
|
|
|
>>> output = C.Dot(input_x1, input_x2) |
|
|
|
>>> output = C.dot(input_x1, input_x2) |
|
|
|
>>> print(output) |
|
|
|
[[[3. 3.]] |
|
|
|
[[3. 3.]]] |
|
|
|
|