Browse Source

Modify dot opt to support pynative mode

tags/v1.2.0-rc1
xutianming 4 years ago
parent
commit
1e7f18a097
2 changed files with 25 additions and 2 deletions
  1. +8
    -2
      mindspore/ops/composite/math_ops.py
  2. +17
    -0
      tests/st/ops/cpu/test_dot_op.py

+ 8
- 2
mindspore/ops/composite/math_ops.py View File

@@ -273,6 +273,13 @@ def _check_invalid_input(x1_shape, x2_shape):
+ f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).') + f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).')




@constexpr
def _get_transpose_shape(x2_shape):
x2_shape_range = tuple(range(len(x2_shape)))
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
return x2_shape_transpose


def dot(x1, x2): def dot(x1, x2):
""" """
Computation a dot product between samples in two tensors. Computation a dot product between samples in two tensors.
@@ -304,8 +311,7 @@ def dot(x1, x2):
_check_invalid_input(x1_shape, x2_shape) _check_invalid_input(x1_shape, x2_shape)


if len(x1_shape) > 2 or len(x2_shape) > 2: if len(x1_shape) > 2 or len(x2_shape) > 2:
x2_shape_range = range(len(x2_shape))
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
x2_shape_transpose = _get_transpose_shape(x2_shape)
x2_transpose = transpose_op(x2, x2_shape_transpose) x2_transpose = transpose_op(x2, x2_shape_transpose)
x1_reshape = reshape_op(x1, (-1, x1_shape[-1])) x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1)) x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))


+ 17
- 0
tests/st/ops/cpu/test_dot_op.py View File

@@ -207,3 +207,20 @@ def test_dot_010():
[[3., 3.]]]).astype(np.float32) [[3., 3.]]]).astype(np.float32)


assert (ms_result_np.asnumpy() == expect_result).all() assert (ms_result_np.asnumpy() == expect_result).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dot_011():
# for document
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
input_x1 = Tensor(np.array(np.ones(shape=[2, 3])).astype(np.float32))
input_x2 = Tensor(np.array(np.ones(shape=[1, 3, 2])).astype(np.float32))

network = NetDot()
ms_result_np = network(input_x1, input_x2)
expect_result = np.array([[[3., 3.]],
[[3., 3.]]]).astype(np.float32)

assert (ms_result_np.asnumpy() == expect_result).all()

Loading…
Cancel
Save