From 04392f31203973a3a7b55395f66d4764cd4c8801 Mon Sep 17 00:00:00 2001 From: lvliang Date: Tue, 22 Dec 2020 21:55:43 +0800 Subject: [PATCH] recovery-already-run --- mindspore/nn/cell.py | 1 + tests/st/pynative/test_pynative_resnet50_ascend.py | 3 +-- tests/st/pynative/test_pynative_resnet50_gpu.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index b51391bbe0..6c85fcdf6b 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -361,6 +361,7 @@ class Cell(Cell_): _pynative_exec.end_graph(self, output, *inputs, **kwargs) for i, cell in enumerate(self.cells()): cell.set_grad(origin_grad[i]) + self._already_run = True return output def _add_attr(self, name, value): diff --git a/tests/st/pynative/test_pynative_resnet50_ascend.py b/tests/st/pynative/test_pynative_resnet50_ascend.py index 0aa9ce3aa8..e7d32b7a9a 100644 --- a/tests/st/pynative/test_pynative_resnet50_ascend.py +++ b/tests/st/pynative/test_pynative_resnet50_ascend.py @@ -38,7 +38,6 @@ random.seed(1) np.random.seed(1) ds.config.set_seed(1) - grad_by_list = CP.GradOperation(get_by_list=True) @@ -404,10 +403,10 @@ def test_pynative_resnet50(): step = step + 1 if step > max_step: break + start_time = time.time() input_data = element["image"] input_label = element["label"] loss_output = net_with_criterion(input_data, input_label) - start_time = time.time() grads = train_network(input_data, input_label) optimizer(grads) end_time = time.time() diff --git a/tests/st/pynative/test_pynative_resnet50_gpu.py b/tests/st/pynative/test_pynative_resnet50_gpu.py index 064ee31017..9402e51853 100644 --- a/tests/st/pynative/test_pynative_resnet50_gpu.py +++ b/tests/st/pynative/test_pynative_resnet50_gpu.py @@ -403,10 +403,10 @@ def test_pynative_resnet50(): step = step + 1 if step > max_step: break + start_time = time.time() input_data = element["image"] input_label = element["label"] loss_output = net_with_criterion(input_data, input_label) - start_time = time.time() grads = train_network(input_data, input_label) optimizer(grads) end_time = time.time()