diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index a27a2bad5d..21a7f9fdda 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -1157,6 +1157,8 @@ class Cell(Cell_): Args: mode (bool): Specifies whether the cell is recomputed. Default: True. """ + if context.get_context("mode") == context.PYNATIVE_MODE: + raise TypeError("Recompute is not supported in pynative mode currently.") Validator.check_bool(mode) self._set_recompute_scope(mode) for cell in self.cells(): diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 73327c586a..280725b001 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -220,7 +220,7 @@ class Primitive(Primitive_): """ Whether the primitive will update the value of parameter.""" return self._update_parameter - def recompute(self, mode): + def recompute(self, mode=True): """ Set the primitive recomputed. If a primitive set recomputed feeds into a gradient node, we will compute it again for the gradient node after the forward computation. @@ -228,6 +228,8 @@ class Primitive(Primitive_): Args: mode (bool): Specifies whether the primitive is recomputed. Default: True. """ + if context.get_context("mode") == context.PYNATIVE_MODE: + raise TypeError("Recompute is not supported in pynative mode currently.") Validator.check_bool(mode) self.add_prim_attr("recompute", mode) return self