| @@ -202,9 +202,13 @@ class Cell: | |||||
| out = self.compile_and_run(*inputs) | out = self.compile_and_run(*inputs) | ||||
| return out | return out | ||||
| self.init_parameters_data() | self.init_parameters_data() | ||||
| orign_grad = [] | |||||
| if self.requires_grad is True: | if self.requires_grad is True: | ||||
| _pynative_exec.set_grad_flag(True) | _pynative_exec.set_grad_flag(True) | ||||
| _pynative_exec.new_graph(self, *inputs) | _pynative_exec.new_graph(self, *inputs) | ||||
| for cell in self.cells(): | |||||
| orign_grad.append(cell.requires_grad) | |||||
| cell.set_grad(True) | |||||
| else: | else: | ||||
| _pynative_exec.set_grad_flag(False) | _pynative_exec.set_grad_flag(False) | ||||
| if self.enable_hook: | if self.enable_hook: | ||||
| @@ -215,6 +219,8 @@ class Cell: | |||||
| output = output.data | output = output.data | ||||
| if self.requires_grad is True: | if self.requires_grad is True: | ||||
| _pynative_exec.end_graph(self, output, *inputs) | _pynative_exec.end_graph(self, output, *inputs) | ||||
| for i, cell in enumerate(self.cells()): | |||||
| cell.set_grad(orign_grad[i]) | |||||
| self._is_run = True | self._is_run = True | ||||
| return output | return output | ||||
| @@ -744,7 +750,7 @@ class Cell: | |||||
| return self | return self | ||||
| def set_grad(self, mode=True): | def set_grad(self, mode=True): | ||||
| self.add_flags_recursive(requires_grad=mode) | |||||
| self.requires_grad = mode | |||||
| return self | return self | ||||
| def set_train(self, mode=True): | def set_train(self, mode=True): | ||||