From 971c2d3c6dc348f44e29e60f64a3c7997c41e491 Mon Sep 17 00:00:00 2001
From: liangzelang
Date: Mon, 19 Apr 2021 17:37:03 +0800
Subject: [PATCH] [control flow]update st testcases for while bp

---
 tests/st/control/test_cont_grad.py | 138 +++++++++++++++++++++++++----
 1 file changed, 120 insertions(+), 18 deletions(-)

diff --git a/tests/st/control/test_cont_grad.py b/tests/st/control/test_cont_grad.py
index ae87f6e317..2874174dd4 100644
--- a/tests/st/control/test_cont_grad.py
+++ b/tests/st/control/test_cont_grad.py
@@ -104,6 +104,10 @@ def test_while_with_const_param_grad():
     assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
     assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_with_variable_grad():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -166,7 +170,10 @@ def test_while_with_param_forward():
     expect = np.array([[[6, 8], [10, 12]], [[19, 22], [25, 28]]], dtype=np.int32)
     assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
 
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_endless_case():
     """endless case when optimization"""
     class MyWhileNet(nn.Cell):
@@ -235,6 +242,10 @@ def test_while_with_param_grad():
     expect = np.array([[[2, 2], [2, 2]], [[2, 2], [2, 2]]], dtype=np.int32)
     assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_with_param_forward_with_const_branch():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -266,7 +277,10 @@ def test_while_with_param_forward_with_const_branch():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_opt_endless():
     """endless during optimization case"""
     class MyWhileNet(nn.Cell):
@@ -307,6 +321,12 @@ def test_while_opt_endless():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+
+@pytest.mark.skip(reason="not supported yet")
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_no_while_call():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -336,7 +356,10 @@ def test_no_while_call():
     pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_with_param_grad_with_const_branch():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -377,6 +400,11 @@ def test_while_with_param_grad_with_const_branch():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.skip(reason="not supported yet")
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_for_while_with_param_grad_with_const_branch():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -420,6 +448,10 @@ def test_for_while_with_param_grad_with_const_branch():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_for_while_with_param_grad_basic():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -460,6 +492,10 @@ def test_for_while_with_param_grad_basic():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_for_while_with_param_grad_normal():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -500,6 +536,10 @@ def test_for_while_with_param_grad_normal():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_with_param_basic_grad():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -537,6 +577,10 @@ def test_while_with_param_basic_grad():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_with_param_basic_grad_mul():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -574,6 +618,10 @@ def test_while_with_param_basic_grad_mul():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_with_param_basic_grad_two():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -613,6 +661,10 @@ def test_while_with_param_basic_grad_two():
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
     assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_with_param_basic_grad_three():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -654,6 +706,10 @@ def test_while_with_param_basic_grad_three():
     assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
     assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_if_with_param_grad():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -694,6 +750,11 @@ def test_while_if_with_param_grad():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.skip(reason="not supported yet")
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_while_with_param_grad_not_enter_while():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -730,6 +791,10 @@ def test_while_with_param_grad_not_enter_while():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_with_param_if_by_if_forward():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -762,7 +827,10 @@ def test_with_param_if_by_if_forward():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_with_param_if_by_if_grad_inputs():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -801,6 +869,10 @@ def test_with_param_if_by_if_grad_inputs():
     assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
     assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_with_param_if_by_if_grad_parameter():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -838,6 +910,10 @@ def test_with_param_if_by_if_grad_parameter():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_with_param_if_by_if_grad_param_excute_null():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -873,6 +949,10 @@ def test_with_param_if_by_if_grad_param_excute_null():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_if_by_if_return_inside_grad():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -910,6 +990,10 @@ def test_if_by_if_return_inside_grad():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_if_by_if_forward():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -948,7 +1032,10 @@ def test_if_by_if_forward():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_if_by_if_forward_control_tuple_switch():
     """tuple_get from switch op will generate new switch inside to eliminate tuple_get"""
     class Branch3Net(nn.Cell):
@@ -1012,9 +1099,10 @@ def test_if_by_if_forward_control_tuple_switch():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
-
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_if_by_if_forward_control_inside_net():
     class Branch3Net(nn.Cell):
         def __init__(self):
@@ -1077,8 +1165,10 @@ def test_if_by_if_forward_control_inside_net():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_if_by_if_forward_use_namespace():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -1117,7 +1207,10 @@ def test_if_by_if_forward_use_namespace():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_if_by_if_forward_use_global_op():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -1160,7 +1253,10 @@ def test_if_by_if_forward_use_global_op():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_for_with_if_by_if_forward():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -1190,8 +1286,10 @@ def test_for_with_if_by_if_forward():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_for_with_if_by_if_forward_namespace():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -1224,7 +1322,10 @@ def test_for_with_if_by_if_forward_namespace():
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
 
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_if_by_if_forward_const_branch_inner():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
@@ -1267,9 +1368,10 @@ def test_if_by_if_forward_const_branch_inner():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
-
-
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_if_by_if_forward_all_const_branch():
     class MyIfByIfNet(nn.Cell):
         def __init__(self):
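
Reviewer note (not part of the patch): after it is applied, each enabled case carries
the same four-marker stack. A minimal sketch of the resulting shape is below; the test
body is elided, and it is an assumption (not shown by this diff) that these custom
markers are registered in the repository's pytest configuration:

    import pytest

    @pytest.mark.level0                        # test-level gate used by CI
    @pytest.mark.platform_arm_ascend_training  # ARM Ascend training platform tag
    @pytest.mark.platform_x86_ascend_training  # x86 Ascend training platform tag
    @pytest.mark.env_onecard                   # single-card environment tag
    def test_while_with_param_basic_grad():
        ...  # body unchanged by this patch

Assuming the markers are registered, a subset can be selected with pytest's standard
marker expression option, e.g. pytest -m level0 tests/st/control/test_cont_grad.py.
The three cases additionally marked @pytest.mark.skip(reason="not supported yet") are
collected but not executed.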