Browse Source

fix grad flag update issue in pynative

tags/v0.7.0-beta
kingfo 5 years ago
parent
commit
28dabf0332
6 changed files with 21 additions and 11 deletions
  1. +6
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
  2. +1
    -0
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  3. +7
    -7
      mindspore/nn/cell.py
  4. +3
    -3
      mindspore/ops/composite/base.py
  5. +3
    -0
      tests/ut/cpp/optimizer/lib_test.cc
  6. +1
    -1
      tests/ut/python/pynative_mode/test_hook.py

+ 6
- 0
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc View File

@@ -20,6 +20,9 @@ namespace mindspore {
namespace opt {
namespace irpass {
AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
return nullptr;
}
PatternNode x, y, z, xs;
PConstant one_(node, false, 1);
PConstant one_scalar_(node, false, 1, true);
@@ -68,6 +71,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
}

AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
return nullptr;
}
PatternNode x, y;
PConstant zero_(node, false, 0);



+ 1
- 0
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -1223,6 +1223,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
}

MS_LOG(DEBUG) << "Clear";
grad_flag_ = false;
top_g_ = nullptr;
df_builder_ = nullptr;
curr_g_ = nullptr;


+ 7
- 7
mindspore/nn/cell.py View File

@@ -84,16 +84,16 @@ class Cell:
self._backward_hook = None
self.enable_hook = False
self._bprop_debug = False
self._is_run = False
self._already_run = False
self.cell_type = None

@property
def is_run(self):
return self._is_run
def already_run(self):
return self._already_run

@is_run.setter
def is_run(self, value):
self._is_run = value
@already_run.setter
def already_run(self, value):
self._already_run = value

@property
def create_time(self):
@@ -260,7 +260,7 @@ class Cell:
_pynative_exec.end_graph(self, output, *inputs)
for i, cell in enumerate(self.cells()):
cell.set_grad(orign_grad[i])
self._is_run = True
self._already_run = True
return output

def __setattr__(self, name, value):


+ 3
- 3
mindspore/ops/composite/base.py View File

@@ -129,14 +129,14 @@ class GradOperation(GradOperation_):
output = fn(*args)
_pynative_exec.end_graph(fn, output, *args)
else:
if fn.is_run and not fn.requires_grad:
if fn.already_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
if not fn.is_run:
if not fn.already_run:
self.need_forward = True
print("already has forward run before grad by user")
if self.need_forward:
fn.set_grad()
fn(*args)
fn.already_run = False

def __call__(self, fn, weights=None):
grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)


+ 3
- 0
tests/ut/cpp/optimizer/lib_test.cc View File

@@ -40,6 +40,9 @@ class TestOptLib : public UT::Common {
void SetUp() {
UT::InitPythonPath();
parse::data_converter::ClearObjectCache();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
}
FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) {
equiv_node.clear();


+ 1
- 1
tests/ut/python/pynative_mode/test_hook.py View File

@@ -152,7 +152,7 @@ def test_hook():
assert cell_hook_done
assert var_hook_done
assert cell_bprop_done
print(loss_output.asnumpy().shape)
print(loss_output.asnumpy())


bprop_debug = False


Loading…
Cancel
Save