Browse Source

Fix continuous calls for recompute api

tags/v1.2.0-rc1
yujianfeng 4 years ago
parent
commit
7cac3d3a47
3 changed files with 10 additions and 6 deletions
  1. +1
    -3
      mindspore/ccsrc/frontend/optimizer/recompute.cc
  2. +7
    -3
      mindspore/nn/cell.py
  3. +2
    -0
      mindspore/ops/primitive.py

+ 1
- 3
mindspore/ccsrc/frontend/optimizer/recompute.cc View File

@@ -31,7 +31,6 @@ namespace opt {
namespace {
constexpr auto kGradientsFlag = "Gradients";
constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrNoRecompute = "no_recompute";
bool IsBpropNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
@@ -46,8 +45,7 @@ bool WithRecomputedScope(const AnfNodePtr &node) {
return false;
}
auto full_name_with_scope = node->fullname_with_scope();
return full_name_with_scope.find(kAttrRecompute) == 0 &&
full_name_with_scope.find(kAttrNoRecompute) == full_name_with_scope.npos;
return full_name_with_scope.find(kAttrRecompute) == 0;
}

bool HasRecomputeCNodeAttr(const AnfNodePtr &node) {


+ 7
- 3
mindspore/nn/cell.py View File

@@ -913,8 +913,11 @@ class Cell(Cell_):
"""Sets the name on the first time."""
if self._scope is None:
self._scope = name
elif self._scope == 'recomputed':
self._scope = self._scope + "_" + name
elif self._scope == 'recompute':
if name is None:
self._scope = None
elif name != 'recompute':
self._scope = self._scope + '_' + name

def _children_scope_recursive(self, parent_prefix='Default'):
"""Generates the scope of each layer of the network recursively."""
@@ -1102,10 +1105,11 @@ class Cell(Cell_):
Args:
mode (bool): Specifies whether the cell is recomputed. Default: True.
"""
Validator.check_bool(mode)
if mode is True:
self._set_scope("recompute")
else:
self._set_scope("no_recompute")
self._set_scope(None)
for cell in self.cells():
cell.recompute(mode)



+ 2
- 0
mindspore/ops/primitive.py View File

@@ -19,6 +19,7 @@ import copy
from mindspore.common.api import _wrap_func
from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type
from .._checkparam import Validator
from . import signature as sig


@@ -213,6 +214,7 @@ class Primitive(Primitive_):
Args:
mode (bool): Specifies whether the primitive is recomputed. Default: True.
"""
Validator.check_bool(mode)
self.add_prim_attr("recompute", mode)
return self



Loading…
Cancel
Save