|
|
|
@@ -220,7 +220,7 @@ class Primitive(Primitive_): |
|
|
|
""" Whether the primitive will update the value of parameter.""" |
|
|
|
return self._update_parameter |
|
|
|
|
|
|
|
def recompute(self, mode): |
|
|
|
def recompute(self, mode=True): |
|
|
|
""" |
|
|
|
Set the primitive recomputed. If a primitive set recomputed feeds into a gradient node, |
|
|
|
we will compute it again for the gradient node after the forward computation. |
|
|
|
@@ -228,6 +228,8 @@ class Primitive(Primitive_): |
|
|
|
Args: |
|
|
|
mode (bool): Specifies whether the primitive is recomputed. Default: True. |
|
|
|
""" |
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE: |
|
|
|
raise TypeError("Recompute is not supported in pynative mode currently.") |
|
|
|
Validator.check_bool(mode) |
|
|
|
self.add_prim_attr("recompute", mode) |
|
|
|
return self |
|
|
|
|