Browse Source

Raise error when calling recompute api in pynative mode

tags/v1.2.0-rc1
yujianfeng 4 years ago
parent
commit
8dbb372228
2 changed files with 5 additions and 1 deletions
  1. +2
    -0
      mindspore/nn/cell.py
  2. +3
    -1
      mindspore/ops/primitive.py

+ 2
- 0
mindspore/nn/cell.py View File

@@ -1157,6 +1157,8 @@ class Cell(Cell_):
Args:
mode (bool): Specifies whether the cell is recomputed. Default: True.
"""
if context.get_context("mode") == context.PYNATIVE_MODE:
raise TypeError("Recompute is not supported in pynative mode currently.")
Validator.check_bool(mode)
self._set_recompute_scope(mode)
for cell in self.cells():


+ 3
- 1
mindspore/ops/primitive.py View File

@@ -220,7 +220,7 @@ class Primitive(Primitive_):
""" Whether the primitive will update the value of parameter."""
return self._update_parameter

def recompute(self, mode):
def recompute(self, mode=True):
"""
Set the primitive recomputed. If a primitive set recomputed feeds into a gradient node,
we will compute it again for the gradient node after the forward computation.
@@ -228,6 +228,8 @@ class Primitive(Primitive_):
Args:
mode (bool): Specifies whether the primitive is recomputed. Default: True.
"""
if context.get_context("mode") == context.PYNATIVE_MODE:
raise TypeError("Recompute is not supported in pynative mode currently.")
Validator.check_bool(mode)
self.add_prim_attr("recompute", mode)
return self


Loading…
Cancel
Save