Browse Source

!30701 Add some st test cases for hypermap

Merge pull request !30701 from LiangZhibo/hypermap
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
9869865c6a
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 91 additions and 2 deletions
  1. +91
    -2
      tests/st/hypermap/test_hypermap.py

+ 91
- 2
tests/st/hypermap/test_hypermap.py View File

@@ -32,6 +32,11 @@ double_elements_fg = C.MultitypeFuncGraph("double_elements_fg")
def double_elements_fg_for_tensor_tuple(x, y):
return P.Tile()(x, y)

@double_elements_fg.register("Tensor", "List")
def double_elements_fg_for_tensor_list(x, y):
return x + y[0]


class HyperMapNet(nn.Cell):
def __init__(self, fg):
super(HyperMapNet, self).__init__()
@@ -47,7 +52,7 @@ class HyperMapNet(nn.Cell):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_element_hypermap():
def test_single_element_hypermap_with_tensor_input():
"""
Feature: HyperMap
Description: Test whether the HyperMap with single tensor input can run successfully.
@@ -70,7 +75,7 @@ def test_single_element_hypermap():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap():
def test_double_elements_hypermap_tensor_tuple_inputs():
"""
Feature: HyperMap
Description: Test whether the HyperMap with tensor and tuple inputs can run successfully.
@@ -88,3 +93,87 @@ def test_double_elements_hypermap():
assert isinstance(output[1], Tensor)
assert np.allclose(output[0].asnumpy(), expect_output_1)
assert np.allclose(output[1].asnumpy(), expect_output_2)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_tensor_list_inputs():
"""
Feature: HyperMap
Description: Test whether the HyperMap with tensor and list inputs can run successfully.
Expectation: success.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
y = ([1, 2], [2, 1])
common_map = HyperMapNet(double_elements_fg)
output = common_map((x, y))
expect_output_1 = np.array([2.0, 3.0, 4.0])
expect_output_2 = np.array([6.0, 7.0, 8.0])
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], Tensor)
assert isinstance(output[1], Tensor)
assert np.allclose(output[0].asnumpy(), expect_output_1)
assert np.allclose(output[1].asnumpy(), expect_output_2)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_doubel_elements_hypermap_correct_mix_inputs():
"""
Feature: HyperMap
Description: Test whether the HyperMap with mix correct inputs (Tensor + Tuple and Tensor + List)
can run successfully.
Expectation: success.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
y = ((1, 2), [2, 1])
common_map = HyperMapNet(double_elements_fg)
output = common_map((x, y))
expect_output_1 = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
expect_output_2 = np.array([6.0, 7.0, 8.0])
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], Tensor)
assert isinstance(output[1], Tensor)
assert np.allclose(output[0].asnumpy(), expect_output_1)
assert np.allclose(output[1].asnumpy(), expect_output_2)



@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_inputs_length_mismatch():
"""
Feature: HyperMap
Description: When the inputs to hypermap have different length, error will be raised.
Expectation: error.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
y = ((1, 2), (2, 1), (5, 6))
common_map = HyperMapNet(double_elements_fg)
with pytest.raises(Exception, match="The length of tuples in HyperMap must be the same"):
common_map((x, y))


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_inconsistent_inputs():
"""
Feature: HyperMap
Description: When the inputs to hypermap is inconsistent, error will be raised.
Expectation: error.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
y = [(1, 2), (2, 1)]
common_map = HyperMapNet(double_elements_fg)
with pytest.raises(Exception, match="the types of arguments in HyperMap must be consistent"):
common_map((x, y))

Loading…
Cancel
Save