|
|
|
@@ -185,6 +185,10 @@ def TensorDot(x1, x2, axes): |
|
|
|
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32) |
|
|
|
>>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32) |
|
|
|
>>> output = C.TensorDot(input_x1, input_x2, ((0,1),(1,2))) |
|
|
|
>>> print(output) |
|
|
|
[[2,2,2], |
|
|
|
[2,2,2], |
|
|
|
[2,2,2]] |
|
|
|
""" |
|
|
|
shape_op = P.Shape() |
|
|
|
reshape_op = P.Reshape() |
|
|
|
|