| @@ -20,6 +20,9 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | ||||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { | |||||
| return nullptr; | |||||
| } | |||||
| PatternNode x, y, z, xs; | PatternNode x, y, z, xs; | ||||
| PConstant one_(node, false, 1); | PConstant one_(node, false, 1); | ||||
| PConstant one_scalar_(node, false, 1, true); | PConstant one_scalar_(node, false, 1, true); | ||||
| @@ -68,6 +71,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr | |||||
| } | } | ||||
| AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | ||||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { | |||||
| return nullptr; | |||||
| } | |||||
| PatternNode x, y; | PatternNode x, y; | ||||
| PConstant zero_(node, false, 0); | PConstant zero_(node, false, 0); | ||||
| @@ -1223,6 +1223,7 @@ void PynativeExecutor::Clear(const std::string &flag) { | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Clear"; | MS_LOG(DEBUG) << "Clear"; | ||||
| grad_flag_ = false; | |||||
| top_g_ = nullptr; | top_g_ = nullptr; | ||||
| df_builder_ = nullptr; | df_builder_ = nullptr; | ||||
| curr_g_ = nullptr; | curr_g_ = nullptr; | ||||
| @@ -84,16 +84,16 @@ class Cell: | |||||
| self._backward_hook = None | self._backward_hook = None | ||||
| self.enable_hook = False | self.enable_hook = False | ||||
| self._bprop_debug = False | self._bprop_debug = False | ||||
| self._is_run = False | |||||
| self._already_run = False | |||||
| self.cell_type = None | self.cell_type = None | ||||
| @property | @property | ||||
| def is_run(self): | |||||
| return self._is_run | |||||
| def already_run(self): | |||||
| return self._already_run | |||||
| @is_run.setter | |||||
| def is_run(self, value): | |||||
| self._is_run = value | |||||
| @already_run.setter | |||||
| def already_run(self, value): | |||||
| self._already_run = value | |||||
| @property | @property | ||||
| def create_time(self): | def create_time(self): | ||||
| @@ -260,7 +260,7 @@ class Cell: | |||||
| _pynative_exec.end_graph(self, output, *inputs) | _pynative_exec.end_graph(self, output, *inputs) | ||||
| for i, cell in enumerate(self.cells()): | for i, cell in enumerate(self.cells()): | ||||
| cell.set_grad(orign_grad[i]) | cell.set_grad(orign_grad[i]) | ||||
| self._is_run = True | |||||
| self._already_run = True | |||||
| return output | return output | ||||
| def __setattr__(self, name, value): | def __setattr__(self, name, value): | ||||
| @@ -129,14 +129,14 @@ class GradOperation(GradOperation_): | |||||
| output = fn(*args) | output = fn(*args) | ||||
| _pynative_exec.end_graph(fn, output, *args) | _pynative_exec.end_graph(fn, output, *args) | ||||
| else: | else: | ||||
| if fn.is_run and not fn.requires_grad: | |||||
| if fn.already_run and not fn.requires_grad: | |||||
| raise ValueError("obj must set_grad.") | raise ValueError("obj must set_grad.") | ||||
| if not fn.is_run: | |||||
| if not fn.already_run: | |||||
| self.need_forward = True | self.need_forward = True | ||||
| print("already has forward run before grad by user") | |||||
| if self.need_forward: | if self.need_forward: | ||||
| fn.set_grad() | fn.set_grad() | ||||
| fn(*args) | fn(*args) | ||||
| fn.already_run = False | |||||
| def __call__(self, fn, weights=None): | def __call__(self, fn, weights=None): | ||||
| grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) | grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) | ||||
| @@ -40,6 +40,9 @@ class TestOptLib : public UT::Common { | |||||
| void SetUp() { | void SetUp() { | ||||
| UT::InitPythonPath(); | UT::InitPythonPath(); | ||||
| parse::data_converter::ClearObjectCache(); | parse::data_converter::ClearObjectCache(); | ||||
| auto ms_context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | |||||
| ms_context->set_execution_mode(kGraphMode); | |||||
| } | } | ||||
| FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) { | FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) { | ||||
| equiv_node.clear(); | equiv_node.clear(); | ||||
| @@ -152,7 +152,7 @@ def test_hook(): | |||||
| assert cell_hook_done | assert cell_hook_done | ||||
| assert var_hook_done | assert var_hook_done | ||||
| assert cell_bprop_done | assert cell_bprop_done | ||||
| print(loss_output.asnumpy().shape) | |||||
| print(loss_output.asnumpy()) | |||||
| bprop_debug = False | bprop_debug = False | ||||