Browse Source

!3276 make gpu equal op support int32

Merge pull request !3276 from qujianwei/master
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
5f10417b9f
1 changed files with 11 additions and 0 deletions
  1. +11
    -0
      tests/st/ops/gpu/test_equal_op.py

+ 11
- 0
tests/st/ops/gpu/test_equal_op.py View File

@@ -60,6 +60,11 @@ def test_equal():
y1_np = np.array([0, 1, -3]).astype(np.float32)
y1 = Tensor(y1_np)
expect1 = np.equal(x1_np, y1_np)
x2_np = np.array([0, 1, 3]).astype(np.int32)
x2 = Tensor(x2_np)
y2_np = np.array([0, 1, -3]).astype(np.int32)
y2 = Tensor(y2_np)
expect2 = np.equal(x2_np, y2_np)

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
equal = NetEqual()
@@ -69,6 +74,9 @@ def test_equal():
output1 = equal(x1, y1)
assert np.all(output1.asnumpy() == expect1)
assert output1.shape == expect1.shape
output2 = equal(x2, y2)
assert np.all(output2.asnumpy() == expect2)
assert output2.shape == expect2.shape

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
equal = NetEqual()
@@ -78,6 +86,9 @@ def test_equal():
output1 = equal(x1, y1)
assert np.all(output1.asnumpy() == expect1)
assert output1.shape == expect1.shape
output2 = equal(x2, y2)
assert np.all(output2.asnumpy() == expect2)
assert output2.shape == expect2.shape


@pytest.mark.level0


Loading…
Cancel
Save