From 8dbb3722281ee676c005059fb913bbe35718a05e Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Tue, 9 Mar 2021 16:59:46 +0800 Subject: [PATCH] Raise error when calling recompute api in pynative mode --- mindspore/nn/cell.py | 2 ++ mindspore/ops/primitive.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index a27a2bad5d..21a7f9fdda 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -1157,6 +1157,8 @@ class Cell(Cell_): Args: mode (bool): Specifies whether the cell is recomputed. Default: True. """ + if context.get_context("mode") == context.PYNATIVE_MODE: + raise TypeError("Recompute is not supported in pynative mode currently.") Validator.check_bool(mode) self._set_recompute_scope(mode) for cell in self.cells(): diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 73327c586a..280725b001 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -220,7 +220,7 @@ class Primitive(Primitive_): """ Whether the primitive will update the value of parameter.""" return self._update_parameter - def recompute(self, mode): + def recompute(self, mode=True): """ Set the primitive recomputed. If a primitive set recomputed feeds into a gradient node, we will compute it again for the gradient node after the forward computation. @@ -228,6 +228,8 @@ class Primitive(Primitive_): Args: mode (bool): Specifies whether the primitive is recomputed. Default: True. """ + if context.get_context("mode") == context.PYNATIVE_MODE: + raise TypeError("Recompute is not supported in pynative mode currently.") Validator.check_bool(mode) self.add_prim_attr("recompute", mode) return self