| @@ -31,7 +31,6 @@ namespace opt { | |||||
| namespace { | namespace { | ||||
| constexpr auto kGradientsFlag = "Gradients"; | constexpr auto kGradientsFlag = "Gradients"; | ||||
| constexpr auto kAttrRecompute = "recompute"; | constexpr auto kAttrRecompute = "recompute"; | ||||
| constexpr auto kAttrNoRecompute = "no_recompute"; | |||||
| bool IsBpropNode(const AnfNodePtr &node) { | bool IsBpropNode(const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| @@ -46,8 +45,7 @@ bool WithRecomputedScope(const AnfNodePtr &node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto full_name_with_scope = node->fullname_with_scope(); | auto full_name_with_scope = node->fullname_with_scope(); | ||||
| return full_name_with_scope.find(kAttrRecompute) == 0 && | |||||
| full_name_with_scope.find(kAttrNoRecompute) == full_name_with_scope.npos; | |||||
| return full_name_with_scope.find(kAttrRecompute) == 0; | |||||
| } | } | ||||
| bool HasRecomputeCNodeAttr(const AnfNodePtr &node) { | bool HasRecomputeCNodeAttr(const AnfNodePtr &node) { | ||||
| @@ -913,8 +913,11 @@ class Cell(Cell_): | |||||
| """Sets the name on the first time.""" | """Sets the name on the first time.""" | ||||
| if self._scope is None: | if self._scope is None: | ||||
| self._scope = name | self._scope = name | ||||
| elif self._scope == 'recomputed': | |||||
| self._scope = self._scope + "_" + name | |||||
| elif self._scope == 'recompute': | |||||
| if name is None: | |||||
| self._scope = None | |||||
| elif name != 'recompute': | |||||
| self._scope = self._scope + '_' + name | |||||
| def _children_scope_recursive(self, parent_prefix='Default'): | def _children_scope_recursive(self, parent_prefix='Default'): | ||||
| """Generates the scope of each layer of the network recursively.""" | """Generates the scope of each layer of the network recursively.""" | ||||
| @@ -1102,10 +1105,11 @@ class Cell(Cell_): | |||||
| Args: | Args: | ||||
| mode (bool): Specifies whether the cell is recomputed. Default: True. | mode (bool): Specifies whether the cell is recomputed. Default: True. | ||||
| """ | """ | ||||
| Validator.check_bool(mode) | |||||
| if mode is True: | if mode is True: | ||||
| self._set_scope("recompute") | self._set_scope("recompute") | ||||
| else: | else: | ||||
| self._set_scope("no_recompute") | |||||
| self._set_scope(None) | |||||
| for cell in self.cells(): | for cell in self.cells(): | ||||
| cell.recompute(mode) | cell.recompute(mode) | ||||
| @@ -19,6 +19,7 @@ import copy | |||||
| from mindspore.common.api import _wrap_func | from mindspore.common.api import _wrap_func | ||||
| from mindspore import context | from mindspore import context | ||||
| from .._c_expression import Primitive_, real_run_op, prim_type | from .._c_expression import Primitive_, real_run_op, prim_type | ||||
| from .._checkparam import Validator | |||||
| from . import signature as sig | from . import signature as sig | ||||
| @@ -213,6 +214,7 @@ class Primitive(Primitive_): | |||||
| Args: | Args: | ||||
| mode (bool): Specifies whether the primitive is recomputed. Default: True. | mode (bool): Specifies whether the primitive is recomputed. Default: True. | ||||
| """ | """ | ||||
| Validator.check_bool(mode) | |||||
| self.add_prim_attr("recompute", mode) | self.add_prim_attr("recompute", mode) | ||||
| return self | return self | ||||