Browse Source

Update recompute python api

tags/v1.2.0-rc1
yujianfeng 4 years ago
parent
commit
047e006aab
2 changed files with 11 additions and 4 deletions
  1. +8
    -2
      mindspore/nn/cell.py
  2. +3
    -2
      mindspore/ops/primitive.py

+ 8
- 2
mindspore/nn/cell.py View File

@@ -1144,8 +1144,14 @@ class Cell(Cell_):

def recompute(self, mode=True):
"""
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad
node and is set recomputed, we will compute it again for the grad node after the forward computation.
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
set recomputed feeds into a gradient node, we will compute it again for the gradient node
after the forward computation.

Note:
If the recompute api of a primtive in this cell is also called, the recompute mode of this
primitive is subject to the recompute api of the primitive.

Args:
mode (bool): Specifies whether the cell is recomputed. Default: True.
"""


+ 3
- 2
mindspore/ops/primitive.py View File

@@ -222,8 +222,9 @@ class Primitive(Primitive_):

def recompute(self, mode):
"""
Set the primitive recomputed. If a primitive feeds into a grad node and is set recomputed,
we will compute it again for the grad node after the forward computation.
Set the primitive recomputed. If a primitive set recomputed feeds into a gradient node,
we will compute it again for the gradient node after the forward computation.

Args:
mode (bool): Specifies whether the primitive is recomputed. Default: True.
"""


Loading…
Cancel
Save