diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 716019f341..715b141fc4 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -1144,8 +1144,14 @@ class Cell(Cell_): def recompute(self, mode=True): """ - Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad - node and is set recomputed, we will compute it again for the grad node after the forward computation. + Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive + set recomputed feeds into a gradient node, we will compute it again for the gradient node + after the forward computation. + + Note: + If the recompute api of a primtive in this cell is also called, the recompute mode of this + primitive is subject to the recompute api of the primitive. + Args: mode (bool): Specifies whether the cell is recomputed. Default: True. """ diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index ade3d1d09c..73327c586a 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -222,8 +222,9 @@ class Primitive(Primitive_): def recompute(self, mode): """ - Set the primitive recomputed. If a primitive feeds into a grad node and is set recomputed, - we will compute it again for the grad node after the forward computation. + Set the primitive recomputed. If a primitive set recomputed feeds into a gradient node, + we will compute it again for the gradient node after the forward computation. + Args: mode (bool): Specifies whether the primitive is recomputed. Default: True. """