From: @joylvliang Reviewed-by: @kisnwang,@chujinjin Signed-off-by: @chujinjintags/v1.1.0
| @@ -361,6 +361,7 @@ class Cell(Cell_): | |||||
| _pynative_exec.end_graph(self, output, *inputs, **kwargs) | _pynative_exec.end_graph(self, output, *inputs, **kwargs) | ||||
| for i, cell in enumerate(self.cells()): | for i, cell in enumerate(self.cells()): | ||||
| cell.set_grad(origin_grad[i]) | cell.set_grad(origin_grad[i]) | ||||
| self._already_run = True | |||||
| return output | return output | ||||
| def _add_attr(self, name, value): | def _add_attr(self, name, value): | ||||
| @@ -38,7 +38,6 @@ random.seed(1) | |||||
| np.random.seed(1) | np.random.seed(1) | ||||
| ds.config.set_seed(1) | ds.config.set_seed(1) | ||||
| grad_by_list = CP.GradOperation(get_by_list=True) | grad_by_list = CP.GradOperation(get_by_list=True) | ||||
| @@ -404,10 +403,10 @@ def test_pynative_resnet50(): | |||||
| step = step + 1 | step = step + 1 | ||||
| if step > max_step: | if step > max_step: | ||||
| break | break | ||||
| start_time = time.time() | |||||
| input_data = element["image"] | input_data = element["image"] | ||||
| input_label = element["label"] | input_label = element["label"] | ||||
| loss_output = net_with_criterion(input_data, input_label) | loss_output = net_with_criterion(input_data, input_label) | ||||
| start_time = time.time() | |||||
| grads = train_network(input_data, input_label) | grads = train_network(input_data, input_label) | ||||
| optimizer(grads) | optimizer(grads) | ||||
| end_time = time.time() | end_time = time.time() | ||||
| @@ -403,10 +403,10 @@ def test_pynative_resnet50(): | |||||
| step = step + 1 | step = step + 1 | ||||
| if step > max_step: | if step > max_step: | ||||
| break | break | ||||
| start_time = time.time() | |||||
| input_data = element["image"] | input_data = element["image"] | ||||
| input_label = element["label"] | input_label = element["label"] | ||||
| loss_output = net_with_criterion(input_data, input_label) | loss_output = net_with_criterion(input_data, input_label) | ||||
| start_time = time.time() | |||||
| grads = train_network(input_data, input_label) | grads = train_network(input_data, input_label) | ||||
| optimizer(grads) | optimizer(grads) | ||||
| end_time = time.time() | end_time = time.time() | ||||