|
|
|
@@ -1144,8 +1144,14 @@ class Cell(Cell_): |
|
|
|
|
|
|
|
def recompute(self, mode=True): |
|
|
|
""" |
|
|
|
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad |
|
|
|
node and is set recomputed, we will compute it again for the grad node after the forward computation. |
|
|
|
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive |
|
|
|
set recomputed feeds into a gradient node, we will compute it again for the gradient node |
|
|
|
after the forward computation. |
|
|
|
|
|
|
|
Note: |
|
|
|
If the recompute api of a primtive in this cell is also called, the recompute mode of this |
|
|
|
primitive is subject to the recompute api of the primitive. |
|
|
|
|
|
|
|
Args: |
|
|
|
mode (bool): Specifies whether the cell is recomputed. Default: True. |
|
|
|
""" |
|
|
|
|