From 171cd83188f7f784081ca81737d1962e6218492c Mon Sep 17 00:00:00 2001 From: He Wei Date: Thu, 29 Apr 2021 09:28:24 +0800 Subject: [PATCH] [test] Fix GRAPH_MODE not restored after PYNATIVE_MODE is set --- .../auto_monad/test_auto_monad_mindtester.py | 31 ++++++++++++------- tests/st/auto_monad/test_effect_random.py | 3 ++ .../models/resnet50/test_resnet50_imagenet.py | 5 +++ 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/st/auto_monad/test_auto_monad_mindtester.py b/tests/st/auto_monad/test_auto_monad_mindtester.py index 9e9b257690..9c3f9874ac 100644 --- a/tests/st/auto_monad/test_auto_monad_mindtester.py +++ b/tests/st/auto_monad/test_auto_monad_mindtester.py @@ -550,9 +550,12 @@ def test_side_effect_grad_two_addn_switch(): inputs = Tensor([9.0], ms.float32) out1 = net.grad_mindspore_impl(inputs, grad_ys) net = SideEffectTwoAddnSwitchNet() - context.set_context(mode=context.PYNATIVE_MODE) - out2 = net.grad_mindspore_impl(inputs, grad_ys) - allclose_nparray(out1[0][0].asnumpy(), out2[0][0].asnumpy(), 0.001, 0.001) + try: + context.set_context(mode=context.PYNATIVE_MODE) + out2 = net.grad_mindspore_impl(inputs, grad_ys) + allclose_nparray(out1[0][0].asnumpy(), out2[0][0].asnumpy(), 0.001, 0.001) + finally: + context.set_context(mode=context.GRAPH_MODE) class SideEffectGradIfNet(Cell): @@ -590,9 +593,12 @@ def test_side_effect_grad_if(): inputs = Tensor([9.0], ms.float32) out1 = net.grad_mindspore_impl(inputs, grad_ys) net = SideEffectGradIfNet() - context.set_context(mode=context.PYNATIVE_MODE) - out2 = net.grad_mindspore_impl(inputs, grad_ys) - allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001) + try: + context.set_context(mode=context.PYNATIVE_MODE) + out2 = net.grad_mindspore_impl(inputs, grad_ys) + allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001) + finally: + context.set_context(mode=context.GRAPH_MODE) class OneInputBprop(Cell): @@ -683,8 +689,11 @@ def test_side_effect_grad_control_flow_assign_depend_while_net(): inputs2 = Tensor([6.0], ms.float32) inputs3 = Tensor([3.0], ms.float32) out1 = net.grad_mindspore_impl(inputs1, inputs2, inputs3, grad_ys) - context.set_context(mode=context.PYNATIVE_MODE) - net = SideEffectControlFlowAssignDependWhileNet() - out2 = net.grad_mindspore_impl(inputs1, inputs2, inputs3, grad_ys) - allclose_nparray(out1[0][0].asnumpy(), out2[0][0].asnumpy(), 0.001, 0.001) - allclose_nparray(out1[1][0].asnumpy(), out2[1][0].asnumpy(), 0.001, 0.001) + try: + context.set_context(mode=context.PYNATIVE_MODE) + net = SideEffectControlFlowAssignDependWhileNet() + out2 = net.grad_mindspore_impl(inputs1, inputs2, inputs3, grad_ys) + allclose_nparray(out1[0][0].asnumpy(), out2[0][0].asnumpy(), 0.001, 0.001) + allclose_nparray(out1[1][0].asnumpy(), out2[1][0].asnumpy(), 0.001, 0.001) + finally: + context.set_context(mode=context.GRAPH_MODE) diff --git a/tests/st/auto_monad/test_effect_random.py b/tests/st/auto_monad/test_effect_random.py index a5a3aab2cf..9fc18c63d2 100644 --- a/tests/st/auto_monad/test_effect_random.py +++ b/tests/st/auto_monad/test_effect_random.py @@ -29,6 +29,7 @@ class Sampling(nn.Cell): """ Test class: sample of Normal distribution. """ + def __init__(self, shape, seed=0): super(Sampling, self).__init__() self.n1 = msd.Normal(0, 1, seed=seed, dtype=dtype.float32) @@ -400,6 +401,8 @@ class RandomChoiceWithMaskNet(nn.Cell): @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard def test_random_choice_with_mask(): + mode = context.get_context('mode') + assert (mode == context.GRAPH_MODE), 'GRAPH_MODE required but got ' + str(mode) net = RandomChoiceWithMaskNet() x = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) index1, index2, index3 = net(x) diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index f25d83b32d..9cf5686924 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -312,6 +312,11 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl): @pytest.mark.platform_x86_ascend_training @pytest.mark.env_single def test_resnet_and_resnet_thor_imagenet_4p(): + # reset context + context.set_context(save_graphs=False, enable_graph_kernel=False, enable_sparse=False) + context.reset_auto_parallel_context() + context.reset_ps_context() + q = Queue() q2 = Queue()