Browse Source

!1907 fix pynative param bug

Merge pull request !1907 from flywind/fix_pynative_bug
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
095e41eff3
1 changed files with 7 additions and 1 deletions
  1. +7
    -1
      mindspore/nn/cell.py

+ 7
- 1
mindspore/nn/cell.py View File

@@ -202,9 +202,13 @@ class Cell:
out = self.compile_and_run(*inputs)
return out
self.init_parameters_data()
orign_grad = []
if self.requires_grad is True:
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(self, *inputs)
for cell in self.cells():
orign_grad.append(cell.requires_grad)
cell.set_grad(True)
else:
_pynative_exec.set_grad_flag(False)
if self.enable_hook:
@@ -215,6 +219,8 @@ class Cell:
output = output.data
if self.requires_grad is True:
_pynative_exec.end_graph(self, output, *inputs)
for i, cell in enumerate(self.cells()):
cell.set_grad(orign_grad[i])
self._is_run = True
return output

@@ -744,7 +750,7 @@ class Cell:
return self

def set_grad(self, mode=True):
self.add_flags_recursive(requires_grad=mode)
self.requires_grad = mode
return self

def set_train(self, mode=True):


Loading…
Cancel
Save