|
|
@@ -202,9 +202,13 @@ class Cell: |
|
|
out = self.compile_and_run(*inputs) |
|
|
out = self.compile_and_run(*inputs) |
|
|
return out |
|
|
return out |
|
|
self.init_parameters_data() |
|
|
self.init_parameters_data() |
|
|
|
|
|
orign_grad = [] |
|
|
if self.requires_grad is True: |
|
|
if self.requires_grad is True: |
|
|
_pynative_exec.set_grad_flag(True) |
|
|
_pynative_exec.set_grad_flag(True) |
|
|
_pynative_exec.new_graph(self, *inputs) |
|
|
_pynative_exec.new_graph(self, *inputs) |
|
|
|
|
|
for cell in self.cells(): |
|
|
|
|
|
orign_grad.append(cell.requires_grad) |
|
|
|
|
|
cell.set_grad(True) |
|
|
else: |
|
|
else: |
|
|
_pynative_exec.set_grad_flag(False) |
|
|
_pynative_exec.set_grad_flag(False) |
|
|
if self.enable_hook: |
|
|
if self.enable_hook: |
|
|
@@ -215,6 +219,8 @@ class Cell: |
|
|
output = output.data |
|
|
output = output.data |
|
|
if self.requires_grad is True: |
|
|
if self.requires_grad is True: |
|
|
_pynative_exec.end_graph(self, output, *inputs) |
|
|
_pynative_exec.end_graph(self, output, *inputs) |
|
|
|
|
|
for i, cell in enumerate(self.cells()): |
|
|
|
|
|
cell.set_grad(orign_grad[i]) |
|
|
self._is_run = True |
|
|
self._is_run = True |
|
|
return output |
|
|
return output |
|
|
|
|
|
|
|
|
@@ -744,7 +750,7 @@ class Cell: |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
|
def set_grad(self, mode=True): |
|
|
def set_grad(self, mode=True): |
|
|
self.add_flags_recursive(requires_grad=mode) |
|
|
|
|
|
|
|
|
self.requires_grad = mode |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
|
def set_train(self, mode=True): |
|
|
def set_train(self, mode=True): |
|
|
|