Browse Source

!10356 Recovery already run

From: @joylvliang
Reviewed-by: @kisnwang,@chujinjin
Signed-off-by: @chujinjin
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
4cf9e2b67b
3 changed files with 3 additions and 3 deletions
  1. +1
    -0
      mindspore/nn/cell.py
  2. +1
    -2
      tests/st/pynative/test_pynative_resnet50_ascend.py
  3. +1
    -1
      tests/st/pynative/test_pynative_resnet50_gpu.py

+ 1
- 0
mindspore/nn/cell.py View File

@@ -361,6 +361,7 @@ class Cell(Cell_):
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
for i, cell in enumerate(self.cells()):
cell.set_grad(origin_grad[i])
self._already_run = True
return output

def _add_attr(self, name, value):


+ 1
- 2
tests/st/pynative/test_pynative_resnet50_ascend.py View File

@@ -38,7 +38,6 @@ random.seed(1)
np.random.seed(1)
ds.config.set_seed(1)


grad_by_list = CP.GradOperation(get_by_list=True)


@@ -404,10 +403,10 @@ def test_pynative_resnet50():
step = step + 1
if step > max_step:
break
start_time = time.time()
input_data = element["image"]
input_label = element["label"]
loss_output = net_with_criterion(input_data, input_label)
start_time = time.time()
grads = train_network(input_data, input_label)
optimizer(grads)
end_time = time.time()


+ 1
- 1
tests/st/pynative/test_pynative_resnet50_gpu.py View File

@@ -403,10 +403,10 @@ def test_pynative_resnet50():
step = step + 1
if step > max_step:
break
start_time = time.time()
input_data = element["image"]
input_label = element["label"]
loss_output = net_with_criterion(input_data, input_label)
start_time = time.time()
grads = train_network(input_data, input_label)
optimizer(grads)
end_time = time.time()


Loading…
Cancel
Save